I've been trying to recreate the DINO v1 training setup for a personal project, and I've taken the majority of the code from this repo: https://github.com/facebookresearch/dino
I'm almost done with it, except for one part. In the main_dino.py file there is a function called train_one_epoch, and on line 318 it has:
teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
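Now, I know how PyTorch tensor indexing/slicing works. So if images is a single batch tensor with the shape
(batch size, num crops, c, h, w)
then images[:2] would select the first two samples in the batch, not the two global views of every sample, which doesn't match the comment.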
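To make the slicing question concrete, here is a small sketch comparing what [:2] selects in two possible layouts. The sizes and the variable names (stacked, crop_list) are made up for illustration, and whether images is really a single tensor or a Python list of per-crop tensors is an assumption that has to be checked against the data loader.

```python
import torch

# Hypothetical sizes, chosen only for the demonstration.
batch_size, num_crops, c, h, w = 4, 6, 3, 8, 8

# Layout A: one 5-D tensor shaped (batch size, num crops, c, h, w).
stacked = torch.zeros(batch_size, num_crops, c, h, w)
first_two_samples = stacked[:2]    # slices dim 0: first two samples, all crops
first_two_crops = stacked[:, :2]   # slices dim 1: first two crops of every sample

# Layout B: a Python list holding one (batch size, c, h, w) tensor per crop.
crop_list = [torch.zeros(batch_size, c, h, w) for _ in range(num_crops)]
first_two_views = crop_list[:2]    # list slice: the first two crop tensors

print(first_two_samples.shape)
print(first_two_crops.shape)
print(len(first_two_views), first_two_views[0].shape)
```

In layout A, [:2] cuts along the batch dimension. In layout B, the same expression keeps the whole batch and picks the first two crops, which is what the comment on line 318 describes.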
Before train_one_epoch() is called, there is another modification to the models: both the student and teacher models are wrapped with the MultiCropWrapper class. Here is the class's docstring:
class MultiCropWrapper(nn.Module):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
So the MultiCropWrapper class handles the forward passes, and according to the docstring it runs one forward pass per distinct input resolution rather than one per crop.
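A minimal sketch of the behaviour the docstring describes, assuming the wrapper receives a Python list of crop tensors ordered so that crops of the same resolution are adjacent. This is not the repo's code: TinyMultiCropWrapper, the toy backbone and head, and all sizes are hypothetical stand-ins.

```python
import torch
import torch.nn as nn


class TinyMultiCropWrapper(nn.Module):
    """Hypothetical, simplified stand-in for the wrapper described above."""

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # Accept a single tensor as well as a list of crops.
        if not isinstance(x, list):
            x = [x]
        # Count how many consecutive crops share the same width.
        widths = torch.tensor([crop.shape[-1] for crop in x])
        _, counts = torch.unique_consecutive(widths, return_counts=True)
        end_indices = torch.cumsum(counts, dim=0).tolist()

        features, start = [], 0
        for end in end_indices:
            # One backbone pass per resolution group, batched along dim 0.
            features.append(self.backbone(torch.cat(x[start:end])))
            start = end
        # The head runs once on the concatenated features.
        return self.head(torch.cat(features))


# Toy backbone: global average pool, (N, 3, H, W) -> (N, 3).
backbone = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
head = nn.Linear(3, 5)
model = TinyMultiCropWrapper(backbone, head)

batch_size = 4
# Assumed layout: 2 larger "global" crops followed by 3 smaller "local" crops.
crops = [torch.randn(batch_size, 3, 32, 32) for _ in range(2)]
crops += [torch.randn(batch_size, 3, 16, 16) for _ in range(3)]

student_out = model(crops)      # all 5 crops, two backbone passes
teacher_out = model(crops[:2])  # only the first two crops, one backbone pass
print(student_out.shape, teacher_out.shape)
```

Under these assumptions the output has one row per (crop, sample) pair, so passing crops[:2] yields 2 * batch_size rows while passing every crop yields 5 * batch_size rows.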