I-JEPA: An improved approach towards self-supervised image-based learning
This article discusses the methodology and contributions of a CVPR’23 paper I read very recently and found really interesting. The core idea of the paper is the use of a Joint-Embedding Predictive Architecture to predict the representations of various target blocks within the same image.
If you wish to read the paper, which is a contribution from Facebook AI Research, this is the arXiv link (2301.08243).
So, as mentioned above, the Image-based Joint-Embedding Predictive Architecture, or I-JEPA in short, proposes learning high-level semantic representations of images without relying on hand-crafted image augmentations such as rotation, scaling and cropping, and without generative methods that learn to predict missing or masked-out portions of an image. Both of these existing families have drawbacks. Invariance-based pre-training methods (built on hand-crafted augmentations) often reach a high semantic level of representation, but the augmentations introduce strong biases, so the learned representations adapt poorly to downstream tasks of a different nature. Generative methods, on the other hand, require less prior knowledge and extend beyond the image modality, but they lack semantic-level representations and underperform invariance-based methods on a variety of tasks.
The proposed solution to this heavy reliance on the two extreme ends of the pre-training spectrum is a new architecture, one that aims to improve the semantic level of self-supervised representations without baking excess prior knowledge into the model.
The Preliminary Architectures
The Joint-Embedding Architecture (JEA) is based on the theory of energy-based models, which assign a high energy value to incompatible (dissimilar) inputs and a low energy value to compatible (similar) inputs; the collective objective of the system is therefore to lower the energy of compatible pairs.
To achieve this, JEAs need hand-crafted augmentations to generate the notions of compatible and incompatible inputs. Additionally, to prevent the recurring issue of representation collapse, wherein the energy profile flattens out and the encoder produces a near-constant output for every input, JEAs rely on contrastive losses that explicitly push apart dissimilar embeddings, on losses that minimise informational redundancy, or on clustering-based methods.
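To make this concrete, below is a minimal PyTorch sketch (my own, not from the paper) of a contrastive loss of the kind used to keep a JEA from collapsing; the encoder producing the two embedding batches z_x and z_y from two augmented views of the same images is assumed.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(z_x, z_y, temperature=0.1):
    """z_x, z_y: (N, D) embeddings of two augmented views of the same N images."""
    z_x = F.normalize(z_x, dim=1)
    z_y = F.normalize(z_y, dim=1)
    logits = z_x @ z_y.t() / temperature                    # (N, N) pairwise similarities
    labels = torch.arange(z_x.size(0), device=z_x.device)   # matching pairs sit on the diagonal
    # Compatible pairs (same image, different view) are pulled together (low energy),
    # while all other pairs in the batch are explicitly pushed apart.
    return F.cross_entropy(logits, labels)
```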
Generative Architectures are designed to learn how to reconstruct a signal y from a compatible masked signal x, which is a corrupted copy of y (signals x and y as labelled in diagram 1b). The conditioning latent variables z are position and mask tokens, which tell the decoder which patches to reconstruct. Hence Generative Architectures learn in the input space, as they reconstruct the inputs themselves.
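For comparison, here is a toy sketch (assuming generic encoder/decoder modules and a patchified image, none of which come from the paper) of this generative setting, where the loss is computed directly in the input space on the masked patches.

```python
import torch
import torch.nn.functional as F

def generative_step(encoder, decoder, patches, mask, mask_token, pos_embed):
    """patches: (N, P, D) patch tokens; mask: (N, P) bool, True = patch to reconstruct;
    mask_token: a learnable (D,) token standing in for hidden patches."""
    visible = patches.masked_fill(mask.unsqueeze(-1), 0.0)          # hide the masked signal
    latent = encoder(visible + pos_embed)                           # encode the corrupted copy x
    # z = position + mask tokens tell the decoder which patches it must reconstruct
    decoder_in = torch.where(mask.unsqueeze(-1), mask_token, latent) + pos_embed
    recon = decoder(decoder_in)
    return F.mse_loss(recon[mask], patches[mask])                   # loss lives in the input space
```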
JEPA - Functioning and Training
Although the Joint-Embedding Predictive Architecture resembles the Generative Architecture, JEPA computes its loss in the embedding space: the major idea is to predict the embedding of a signal y from a compatible signal x, using a predictor network conditioned on latent variables z (refer to diagram 1c).
Hence the main objective of I-JEPA is: given a single context block, predict the representations of the other target blocks within the same image, where these target representations are computed by a target encoder network.
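To make the embedding-space idea from the last two paragraphs concrete, here is a generic JEPA objective as a short sketch (the encoders and predictor are assumed modules, not the paper's code):

```python
import torch.nn.functional as F

def jepa_loss(x_encoder, y_encoder, predictor, x, y, z):
    s_x = x_encoder(x)               # embed the compatible signal x
    s_y = y_encoder(y)               # embed the signal y (the prediction target)
    s_y_hat = predictor(s_x, z)      # predict s_y from s_x, conditioned on z
    return F.mse_loss(s_y_hat, s_y)  # distance is measured between embeddings, not pixels
```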
The I-JEPA architecture uses Vision Transformers (ViTs) for the context encoder, the target encoder and the predictor network.
Now moving on to how the architecture functions, taking reference from Fig 2: the context encoder network takes as input a single context block sampled from the image (analogous to the input x in Fig 1c). From this encoded context, the predictor network, conditioned on position tokens (analogous to the z variable in Generative Architectures, highlighted in red, blue and yellow in Fig 2), predicts the representations of the target blocks at those specific positions. These predicted target representations are then compared against the ground-truth outputs of the target encoder, which is fed the actual input image.
The targets are obtained by sampling M blocks from the output of the target encoder, i.e. from the target representations (these blocks may overlap with one another).
The context block is sampled from the image independently of the target blocks, and because of this independent sampling, any overlap between the targets and the context is removed to avoid trivial predictions. Finally, the masked context block is fed into the context encoder, followed by the predictor network, to produce the predicted embeddings.
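Putting these steps together, here is a simplified sketch of one I-JEPA forward pass; the module names, mask handling and shapes are my own simplifications and not the official implementation:

```python
import torch
import torch.nn.functional as F

def ijepa_step(context_encoder, target_encoder, predictor,
               patches, context_mask, target_masks, pos_embed):
    """
    patches:      (N, P, D) patch tokens of the image
    context_mask: (N, P) bool context block, assumed to already exclude the target regions
    target_masks: list of M (N, P) bool masks, one per target block (same layout across the batch)
    pos_embed:    (1, P, D) positional embeddings
    """
    N = patches.size(0)
    with torch.no_grad():                                   # targets provide no gradient
        s_full = target_encoder(patches + pos_embed)        # representations of the full image

    ctx_in = (patches + pos_embed).masked_fill(~context_mask.unsqueeze(-1), 0.0)
    s_x = context_encoder(ctx_in)                           # encode only the context block

    loss = 0.0
    for m in target_masks:                                  # predict each of the M target blocks
        idx = m[0].nonzero(as_tuple=True)[0]                # patch indices of this target block
        s_y = s_full[:, idx]                                # ground-truth block representations
        z = pos_embed[:, idx].expand(N, -1, -1)             # position tokens conditioning the predictor
        s_y_hat = predictor(s_x, z)                         # predicted block representations
        loss = loss + F.mse_loss(s_y_hat, s_y)
    return loss / len(target_masks)
```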
Ultimately, to train the network, an L2 (mean squared error) loss is applied between the predicted and ground-truth target representations, averaged over the M target blocks:
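In the paper's notation, where $\hat{s}_y(i)$ is the predicted representation of the i-th target block $B_i$ and $s_{y_j}$ is the target encoder's representation of patch j, this loss reads:

$$\frac{1}{M}\sum_{i=1}^{M} D\big(\hat{s}_y(i),\, s_y(i)\big) \;=\; \frac{1}{M}\sum_{i=1}^{M}\sum_{j \in B_i} \big\lVert \hat{s}_{y_j} - s_{y_j} \big\rVert_2^2$$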
The parameters of the context encoder and the predictor network are updated through gradient-based optimisation, while the target encoder's parameters are updated via an exponential moving average of the context encoder's parameters, a technique that is essential when training JEAs with ViTs.
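A minimal sketch of that exponential-moving-average update (the momentum value and schedule here are illustrative, not the paper's exact settings):

```python
import torch

@torch.no_grad()
def ema_update(target_encoder, context_encoder, momentum=0.996):
    # target_params <- m * target_params + (1 - m) * context_params
    for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
        p_t.mul_(momentum).add_(p_c, alpha=1.0 - momentum)
```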
Visualisations, Performance and Conclusion
I-JEPA requires less compute than pixel-reconstruction methods (Generative Architectures) and iBOT (hand-crafted augmentations) because it converges faster, which also allows the training data to be scaled up, and the added diversity improves transfer-learning performance.
This article aimed to dissect the Image-based Joint-Embedding Predictive Architecture for self-supervised learning in a shorter and slightly simpler way. I personally found the paper quite interesting; although generative architectures are progressing much faster in this domain, this architecture does work to fill in possible gaps.
All the diagrams, barring Fig 3 and Fig 4, have been taken directly from the paper. Fig 4 was made by me to give a simpler view of how the architecture functions.
Fig 3 is sourced from here — link
Thank you :)