from napari_convpaint.conv_paint_model import ConvpaintModel
import numpy as np
import matplotlib.pyplot as plt
import skimage

Using Convpaint programmatically (API)

Loading a saved model and using it in a loop

In a typical workflow, you might have trained a model interactively in the napari GUI and now want to use it programmatically for batch processing. Here we load a model trained on sample data and apply it to an image:

cpm = ConvpaintModel(model_path="../sample_data/lily_VGG16.pkl")

img = skimage.data.lily() # Load example image
img = np.moveaxis(img, -1, 0) # Move channel axis to the first position (the convention throughout Convpaint)

segmentation = cpm.segment(img)

For a simple illustration, we extract the first 3 of the image's 4 channels and display them as RGB. On the right, we show the predicted classes.

# Show the image and the prediction next to each other
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Plot the image (scale the 12-bit data to [0, 1] for RGB display)
lily_rgb = skimage.data.lily()[:, :, :3]
axes[0].imshow(lily_rgb / lily_rgb.max())
axes[0].set_title('Image')

cmap = plt.get_cmap('tab20', 3)
cmap.set_under('white')

# Plot the prediction
axes[1].imshow(segmentation, cmap=cmap, interpolation='nearest')
axes[1].set_title('Prediction')

# Disable x and y ticks
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
../_images/b286f58c2550fea359f062348344159f0fd1d2cb562ece42621a10defcce9062.png

Creating and training a new model using the API

We can also create new models from scratch programmatically. This makes it easy to compare the performance of different models.

First let’s look at all the available feature extractors:

ConvpaintModel.get_fe_models_types()
{'vgg16': napari_convpaint.conv_paint_nnlayers.Hookmodel,
 'efficient_netb0': napari_convpaint.conv_paint_nnlayers.Hookmodel,
 'convnext': napari_convpaint.conv_paint_nnlayers.Hookmodel,
 'gaussian_features': napari_convpaint.conv_paint_gaussian.GaussianFeatures,
 'dinov2_vits14_reg': napari_convpaint.conv_paint_dino.DinoFeatures,
 'combo_dino_vgg': napari_convpaint.conv_paint_combo_fe.ComboFeatures,
 'combo_dino_gauss': napari_convpaint.conv_paint_combo_fe.ComboFeatures,
 'combo_dino_ilastik': napari_convpaint.conv_paint_combo_fe.ComboFeatures,
 'vit_small_patch14_reg4_dinov2': napari_convpaint.conv_paint_dino_jafar.DinoJafarFeatures,
 'cellpose_backbone': napari_convpaint.conv_paint_cellpose.CellposeFeatures,
 'ilastik_2d': napari_convpaint.conv_paint_ilastik.IlastikFeatures}
Tip: It's easy to implement your own feature extractor! Have a look at conv_paint_gaussian.py for a minimal example. We've also provided a template file, conv_paint_fe_template.py, with instructions; you can copy and modify it to build your own.
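
To give a flavor of what such an extractor looks like, here is a hypothetical outline. The real base class, method names and registration are defined in conv_paint_fe_template.py, so treat all names below purely as illustrative assumptions:

import numpy as np
from scipy.ndimage import uniform_filter

# Hypothetical sketch of a custom feature extractor; the actual interface
# is defined in conv_paint_fe_template.py, and these names are assumptions.
class MeanFilterFeatures:
    """Return the raw image plus mean-filtered copies as features."""

    def get_features(self, image):
        # Expected output: a (n_features, H, W) stack of feature maps
        return np.stack([image] + [uniform_filter(image, size=s) for s in (3, 9)])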

When using a CNN such as VGG16 as the feature extractor, we need to specify which of its layers to use.

We can print out the selectable layers like this:

cpm2 = ConvpaintModel()
cpm2.get_fe_layer_keys()
['features.0 Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.5 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.7 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.10 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.12 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.14 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.17 Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.19 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.21 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.24 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.26 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
 'features.28 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))']

By default, we've created a model that uses VGG16 with just the first layer for feature extraction:

cpm2.get_param("fe_layers")
['features.0 Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))']
Hint: The VGG16 layers are filtered to only show layers of type Conv2d. Convpaint adds hooks at the selected layers, i.e. it captures their outputs. The forward pass through the network is stopped after the last selected layer to speed up processing. Each hooked layer returns a certain number of feature maps, which are then all concatenated into a single set of features.
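
To illustrate the hook mechanism itself, here is a minimal sketch in plain PyTorch with a torchvision VGG16; this mirrors the idea, but is not Convpaint's internal code:

import torch
import torchvision

# Load a plain torchvision VGG16 (illustration only, not Convpaint internals)
vgg = torchvision.models.vgg16(weights="DEFAULT").eval()

captured = []
def hook(module, inputs, output):
    captured.append(output.detach())

# Hook the first two Conv2d layers (features.0 and features.2)
vgg.features[0].register_forward_hook(hook)
vgg.features[2].register_forward_hook(hook)

with torch.no_grad():
    vgg.features(torch.zeros(1, 3, 224, 224))  # forward pass fills `captured`

# Both hooked layers output 64 feature maps -> 64 + 64 = 128 features in total
print([t.shape for t in captured])

Note that Convpaint additionally stops the forward pass after the last hooked layer; the sketch above omits that optimization.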

Let's create a new VGG16 model using the first 2 layers. We store all the necessary settings in the ConvpaintModel object (the same settings that are otherwise set via the GUI toggles). Note that the layers can be specified either by name or by index (among the available layers).

cpm3 = ConvpaintModel(fe_name="vgg16", fe_layers=[0, 1])
cpm3.get_param("fe_layers")
[0, 1]

Besides the layers for CNNs, there are several other options that can be set in the ConvpaintModel object:

  • fe_scalings specifies the levels of downscaling to use for feature extraction (1 is the original size)

  • fe_order specifies the spline order used to upscale small feature maps (either from the downscaling, or from aggregation in the neural network)

  • with fe_use_min_features=True, only the first n features of each output layer are used, where n is the number of features of the layer that outputs the fewest (for example, with layers outputting 64 and 128 features, the first 64 of each are kept); this can help balance the weight of different layers

  • normalize normalizes the image so that it more closely matches the input expected by the pre-trained network

  • image_downsample allows using a smaller version of the image as input; note that this doesn't change the size of the predicted output, which is rescaled to the original size at the end

For a comprehensive description of all parameters and options, please refer to the separate page.
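
For instance, you can inspect the current settings with the get_param accessor shown above; this sketch assumes get_param accepts the same parameter keys as set_params below:

# Inspect a few current settings of cpm3
# (assumes get_param accepts the same keys as set_params)
for key in ["fe_name", "fe_layers", "fe_scalings", "fe_order", "image_downsample"]:
    print(key, "=", cpm3.get_param(key))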

Let’s create a custom model. As you see below, except for fe_name and the layers (and gpu), all parameters can easily be adjusted after initialization, either one at a time or several in one call:

cpm4 = ConvpaintModel(fe_name="vgg16", fe_layers=[0, 1])
cpm4.set_param("fe_scalings", [1, 2])
cpm4.set_params(fe_order=0, fe_use_min_features=False, image_downsample=3)

Train and test the model on artificial data:

# Create a noisy image
img = np.random.rand(200, 200)
factor = 0.3
img[:50, :] = img[:50, :] * factor
img[80:100, :] = img[80:100, :] * factor

# Draw small rectangles as annotations (int32, as expected for class labels)
annotations = np.zeros((200, 200), dtype=np.int32)
annotations[10:30, 10:30] = 1      # class 1
annotations[170:190, 170:190] = 2  # class 2

# Train the classifier to predict the annotations from the features
cpm4.train(img, annotations)

# Predict the annotations from the image
segmentation = cpm4.segment(img)
0:	learn: 0.5503165	total: 145ms	remaining: 14.4s
...
99:	learn: 0.0041767	total: 345ms	remaining: 0us
# Show the image, the annotations and the prediction
fig, axes = plt.subplots(1, 3, figsize=(12, 6))

# Plot the image
axes[0].imshow(img, cmap='gray')
axes[0].set_title('Image')

cmap = plt.get_cmap('tab20')
cmap.set_under('white')

# Plot the annotations
axes[1].imshow(annotations, cmap=cmap, interpolation='nearest', vmin=1, vmax=3)
axes[1].set_title('Annotations')

# Plot the prediction
axes[2].imshow(segmentation, cmap=cmap, interpolation='nearest', vmin=1, vmax=3)
axes[2].set_title('Prediction')

# Disable x and y ticks
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
../_images/42be6b88ffa6b1e9b55965de21edf0dcb7d7a0563df8da751a3423a50d37a3a0.png

The selected output layers have 64 and 64 output features and we’re using two scalings, so in total we have (64+64)*2 = 256 features. With our ConvpaintModel, we can also extract those features to display and analyze them further:

features = cpm4.get_feature_image(img)
print(f"Number of features: {features.shape}")
plt.imshow(features[0], cmap='gray')
plt.title('First feature map')
plt.show()
Number of features: 256
../_images/c01e1d5d964f1751497cd90ceff94663eeeb80016a82ad62762fd6d42c4b518a.png

Running Convpaint in a loop for batch processing

num_images = 10 # Number of images to generate

imgs = []
segmentations = []

for i in range(num_images):
    # Create a noisy image
    img = np.random.rand(200, 200)
    start_row = np.random.randint(0, 150)
    end_row = start_row + np.random.randint(20, 50)
    img[start_row:end_row, :] = img[start_row:end_row, :] * 0.5
    
    # Make prediction
    segmentation = cpm4.segment(img)
    
    imgs.append(img)
    segmentations.append(segmentation)
# Create a figure with a grid of subplots
fig, axs = plt.subplots(2, num_images, figsize=(12, 3))

for i in range(num_images):
    # Plot sample image
    axs[0, i].imshow(imgs[i], cmap='gray')
    axs[0, i].set_title(f'{i+1}')
    axs[0, i].set_xticks([])
    axs[0, i].set_yticks([])

    # Plot prediction
    axs[1, i].imshow(segmentations[i], cmap=cmap, interpolation='nearest', vmin=1, vmax=3)
    axs[1, i].set_xticks([])
    axs[1, i].set_yticks([])

# Add y labels
axs[0, 0].set_ylabel('Image')
axs[1, 0].set_ylabel('Prediction')

plt.tight_layout()
plt.show()
../_images/6ab6c6bd8a68f4ba4d9a64d373674ae3c0324a8a07b2fcbb801d27b2d76c6c19.png
Note: In the Pertz lab we train a model using the GUI on data live-streamed from the microscope at the beginning of an experiment. The trained model is then used programmatically for smart microscopy approaches, for example enabling optogenetic stimulation of subcellular areas, or automated selection of cells that fit certain shape criteria.

Creating a model using DINOv2 as feature extractor

For the ViT-based DINOv2 model we don’t select layers; instead, we simply extract all patch-based features.

Note that here, we are using another option to initialize a ConvpaintModel. Using an alias to create a pre-defined model is in fact the simplest way to get started with Convpaint. For details, refer to the Feature Extractor page.

Importantly, here we are handling images with an additional dimension. Hence, we need to tell the model what that dimension represents: a third “spatial” dimension (such as z or time), or a “channel” dimension (for example, RGB color channels). Here, we set multi_channel_img to True in accordance with the RGB color channels of the image.

cpm_dino = ConvpaintModel(alias="dino")
cpm_dino.set_param("multi_channel_img", True)

# Create a new dataset
img = skimage.data.stereo_motorcycle()
train_img = np.moveaxis(img[0], -1, 0)
pred_img = np.moveaxis(img[1], -1, 0)
annotations = np.zeros(img[0][:, :, 0].shape, dtype=np.int32)
# Foreground [y, x]
annotations[50:100, 50:100] = 1
annotations[450:500, 500:550] = 1
# Background [y, x]
annotations[200:250, 400:450] = 2
annotations[300:350, 200:400] = 2

# Train the model
cpm_dino.train(train_img, annotations)

# Use it to segment the image
segmentation = cpm_dino.segment(image=pred_img)
0:	learn: 0.3768599	total: 23.4ms	remaining: 2.31s
...
99:	learn: 0.0001619	total: 906ms	remaining: 0us
# Show the images, the annotations and the prediction
fig, axes = plt.subplots(1, 4, figsize=(12, 8))

# Plot the training image
axes[0].imshow(img[0])
axes[0].set_title('Stereo image 0 (training)')

cmap = plt.get_cmap('tab20')
cmap.set_under('white')

# Plot the annotations
axes[1].imshow(annotations, cmap=cmap, interpolation='nearest', vmin=1, vmax=3)
axes[1].set_title('Annotations')

# Plot the image to predict
axes[2].imshow(img[1])
axes[2].set_title('Stereo image 1 (to predict)')

# Plot the prediction
axes[3].imshow(segmentation, cmap=cmap, interpolation='nearest', vmin=1, vmax=3)
axes[3].set_title('Prediction')

# Disable x and y ticks
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
../_images/10723c98d62a726960eb1d75c4f4bef3e410610f3df7230ab6299196bb8fd17c.png

Visualizing the extracted DINOv2 features

feature_image = cpm_dino.get_feature_image(train_img)
num_features = feature_image.shape[0]
print(f"Extracted {num_features} features.")

# Show a random selection of feature maps (sampled without replacement)
grid_size = 5
random_features = np.random.choice(num_features, grid_size * grid_size, replace=False)
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(feature_image[random_features[i]], cmap='viridis')
    ax.set_xticks([])
    ax.set_yticks([])

# Reduce the spacing between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.1)
Extracted 384 features.
../_images/725f30c8701e26d42d220fc613224a870edd44c7f9561194a6b38d8713e49705.png
Tip: It can be useful to visualize the extracted features for troubleshooting and getting an intuition for the level of detail that the model considers. In this example, DINOv2 extracts high-level semantic features like tires, floor, bike body or engine.

Combining features from multiple models

The strengths of different feature extractors can be combined by concatenating their outputs. For example, the good spatial resolution of VGG16 can be combined with the rich semantic features of DINOv2, which yields very successful segmentations on some datasets.

We provide a selection of pre-defined combo feature extractors. You can use these out of the box or as a starting point for your own custom configurations.
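
A quick way to see which combo extractors are available is to filter the feature extractor listing shown earlier (this only relies on the get_fe_models_types call demonstrated above):

# List the pre-defined combo feature extractors
combos = [name for name in ConvpaintModel.get_fe_models_types()
          if name.startswith("combo")]
print(combos)  # ['combo_dino_vgg', 'combo_dino_gauss', 'combo_dino_ilastik']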

cpm_combo = ConvpaintModel(fe_name="combo_dino_vgg", multi_channel_img=True)

stereo = skimage.data.stereo_motorcycle()
train_img = np.moveaxis(stereo[0], -1, 0)
pred_img = np.moveaxis(stereo[1], -1, 0)

# Train on the first image
cpm_combo.train(train_img, annotations)

# Predict on the second
segmentation = cpm_combo.segment(pred_img)
0:	learn: 0.3776665	total: 16.6ms	remaining: 1.64s
...
99:	learn: 0.0001348	total: 1.27s	remaining: 0us
# Show the images, the annotations and the prediction
fig, axes = plt.subplots(1, 4, figsize=(12, 6))

# Plot the image
axes[0].imshow(img[0], cmap='gray')

cmap = plt.get_cmap('tab20')
cmap.set_under('white')

# Plot the annotations
axes[1].imshow(annotations, cmap=cmap, interpolation='nearest', vmin=1, vmax=3)

# Plot the image to predict
axes[2].imshow(img[1], cmap='gray')

# Plot the prediction
axes[3].imshow(segmentation, cmap=cmap, interpolation='nearest', vmin=1, vmax=3)

# Disable x and y ticks
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
../_images/b7137696b75cbc93b5d5900141b81015a19fd505948fe906ce29ed1f0e1ac682.png

Visual comparison of features extracted by DINOv2 vs. VGG16

cpm_vgg = ConvpaintModel(fe_name="vgg16", multi_channel_img=True)
features_vgg = cpm_vgg.get_feature_image(train_img)

cpm_dino = ConvpaintModel(fe_name="dinov2_vits14_reg", multi_channel_img=True)
features_dino = cpm_dino.get_feature_image(train_img)
fig, axes = plt.subplots(1, 3, figsize=(12, 12))

# Plot the image, and one feature map each from DINOv2 and VGG16
axes[0].imshow(img[0])
axes[0].set_title('Original Image')
axes[1].imshow(features_dino[22, :, :], cmap='viridis')
axes[1].set_title('DINOv2 feature')
axes[2].imshow(features_vgg[22, :, :], cmap='viridis')
axes[2].set_title('VGG16 feature')

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

plt.subplots_adjust(wspace=0.1, hspace=0.1)
../_images/bfdaf54b5aa3f801146ef79b55ee011b08b13baa386cdee18e9f24d06bd0a5de.png