I-JEPA: An improved approach towards self-supervised image-based learning

Vedant Palit
5 min read · Jun 27, 2023


This article discusses the methodology and contributions of a CVPR'23 paper I read recently and found really interesting. The core idea of the paper is to use a Joint-Embedding Predictive Architecture to predict the representations of several target blocks within the same image.

If you wish to read the paper, which comes from Facebook AI Research, here is the arXiv link (2301.08243).

So, as discussed before, the Image-based Joint-Embedding Predictive Architecture, or I-JEPA for short, proposes learning high-level semantic representations of images without relying on hand-crafted image augmentations such as rotation, scaling or cropping, and without generative methods that learn to predict missing portions of an image. Both of these families adapt poorly to tasks of a different nature: invariance-based pre-training methods (built on hand-crafted augmentations) often capture high-level semantic representations but lack robustness, while generative methods require less prior knowledge and extend beyond the image modality, yet produce representations of a lower semantic level and underperform invariance-based methods on a variety of tasks.

The proposed solution to this heavy reliance on the two extremes of pre-training is a new architecture that aims to improve the semantic level of self-supervised representations without building excess prior knowledge into the model.

Fig 1: (a) Joint-Embedding Architectures align the embedding outputs of the x- and y-encoders for compatible inputs and push them apart for incompatible inputs. (b) Generative Architectures reconstruct a signal y from an input x, conditioned on hidden variables z. (c) Joint-Embedding Predictive Architectures predict the embedding of a signal y from the input x, using a predictor conditioned on latent variables z.

The Preliminary Architectures

The Joint-Embedding Architecture is rooted in the theory of energy-based models: dissimilar (incompatible) inputs are assigned a high energy value and similar (compatible) inputs a low one, so the collective objective of the system is to lower the energy for compatible pairs.

To achieve this, hand-crafted augmentations become necessary to define which inputs count as compatible and which as incompatible. Additionally, to prevent the recurring issue of representation collapse, in which the energy surface flattens out and the encoder produces a constant output for every input, JEAs rely on contrastive losses that explicitly push apart dissimilar embeddings, on information-redundancy minimisation, or on clustering-based methods.
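As a toy illustration of the energy idea (not something taken from the paper), a margin-based contrastive loss over a pair of embeddings could look like the sketch below; `x_emb` and `y_emb` stand in for the outputs of the x- and y-encoders, and the margin value is arbitrary:

```python
import torch
import torch.nn.functional as F

def contrastive_energy_loss(x_emb, y_emb, compatible, margin=1.0):
    """Toy energy-based objective for a joint-embedding architecture.

    x_emb, y_emb : (batch, dim) embeddings from the x- and y-encoders
    compatible   : (batch,) 1.0 for compatible pairs, 0.0 for incompatible pairs
    """
    # Energy = distance between the two embeddings
    energy = F.pairwise_distance(x_emb, y_emb)

    # Compatible pairs: pull embeddings together (lower the energy)
    pos_term = compatible * energy.pow(2)

    # Incompatible pairs: push embeddings apart up to a margin (raise the energy)
    neg_term = (1 - compatible) * F.relu(margin - energy).pow(2)

    return (pos_term + neg_term).mean()
```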

Generative Architectures, in contrast, learn to reconstruct the signal y from a compatible signal x, which is a masked copy of y (signals x and y as in Fig 1b). The conditioning variables z take the form of position and mask tokens, which tell the decoder which patches to reconstruct. GAs therefore learn in the input space, since they reconstruct the inputs themselves.
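For contrast with the sketch above, a generative objective lives in the input space. A minimal sketch, where `encoder`, `decoder` and the learned `mask_token` are placeholders rather than the paper's actual components:

```python
import torch

def masked_reconstruction_loss(image_patches, mask, encoder, decoder, mask_token):
    """Sketch of a generative (reconstruction-based) objective.

    image_patches : (batch, num_patches, patch_dim) flattened patches of y
    mask          : (batch, num_patches) bool, True where patches are hidden
    mask_token    : (patch_dim,) learned token standing in for hidden patches
    """
    # x is a masked copy of y: hidden patches are replaced by the mask token
    x = torch.where(mask.unsqueeze(-1), mask_token, image_patches)

    # The decoder is told which positions to reconstruct (the conditioning z)
    latent = encoder(x)
    reconstruction = decoder(latent)

    # The loss is computed in the input (pixel) space, only on the hidden patches
    return ((reconstruction - image_patches) ** 2)[mask].mean()
```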

JEPA - Functioning and Training

Although the Joint-Embedding Predictive Architecture resembles the Generative Architecture, JEPA computes its loss in the embedding space: the key idea is to predict the embedding of a signal y from a compatible signal x, using a predictor network conditioned on latent variables z (refer to Fig 1c).

Hence the main objective of I-JEPA is: given a single context block, predict the representations of several target blocks within the same image, where these target representations are computed by a learned target encoder.

Fig 2: A schematic representation of I-JEPA

The architecture of I-JEPA uses Vision Transformers for the context encoder, the target encoder and the predictor network.

Fig 3: Vision Transformers first turn image inputs into patch embeddings, which resemble the token embeddings generated from textual inputs in language transformers; these patch embeddings are then fed into neural networks for downstream tasks
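For readers unfamiliar with ViTs, here is a minimal sketch of the patch-embedding step that Fig 3 describes; the hyper-parameters (224-pixel images, 16-pixel patches, 768-dimensional embeddings) are the usual ViT defaults, not values taken from the paper:

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project them to embeddings."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to patchify + linear projection
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # x: (batch, 3, H, W)
        x = self.proj(x)                       # (batch, embed_dim, H/16, W/16)
        return x.flatten(2).transpose(1, 2)    # (batch, num_patches, embed_dim)


patches = PatchEmbed()(torch.randn(1, 3, 224, 224))
print(patches.shape)  # torch.Size([1, 196, 768])
```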

Moving on to how the architecture functions (refer to Fig 2): the context-encoder network takes as input a single image context block (analogous to the input x in Fig 1c). From the context encoder's output, the predictor network, conditioned on position tokens (analogous to the variable z in generative architectures, highlighted in red, blue and yellow in Fig 2), predicts the representations of the target blocks at those specific positions. These predicted target representations are then compared against the ground-truth outputs of the target encoder, which is applied to the actual input image.
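Translating Fig 2 into a heavily simplified forward pass may help; all names and signatures below are mine, and the actual implementation conditions the predictor on learnable mask tokens with added positional embeddings rather than the bare positional embeddings used here:

```python
import torch

def ijepa_forward(image_patches, context_idx, target_idx_list,
                  context_encoder, target_encoder, predictor, pos_embed):
    """Sketch of one I-JEPA forward pass (all components are placeholders).

    image_patches   : (batch, num_patches, dim) patch embeddings of the full image
    context_idx     : indices of the patches kept in the context block
    target_idx_list : list of M index tensors, one per target block
    pos_embed       : (num_patches, dim) positional embeddings (the conditioning z)
    """
    # Ground truth: the target encoder sees the *full* image; each target block's
    # representation is read off at that block's patch positions.
    with torch.no_grad():
        full_repr = target_encoder(image_patches)          # (batch, num_patches, dim)
    targets = [full_repr[:, idx] for idx in target_idx_list]

    # Context: the context encoder only sees the visible context patches.
    context_repr = context_encoder(image_patches[:, context_idx])

    # The predictor is conditioned on the positions of the block it must predict.
    predictions = [predictor(context_repr, pos_embed[idx]) for idx in target_idx_list]
    return predictions, targets
```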

Fig 4: A simplified view of how the loss function implementation works

The targets are obtained by sampling M blocks from the target encoder's output, i.e. from the target representations (these blocks may overlap).

Fig 5: The context block is chosen by removing the portions that overlap with the sampled target blocks. In the experiments, M (the number of target blocks) = 4, with a scale ratio in (0.15, 0.2) and an aspect ratio in (0.75, 1.5)

The context block is sampled from the image independently of the target blocks, so to rule out trivial predictions, any overlap between the context and the targets is removed. The masked context block is then fed into the context encoder, followed by the predictor network, to produce the predicted embeddings.
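A rough sketch of that sampling logic, with the target-block ranges from Fig 5 hard-coded; the helper names are mine, and the context-block scale range is only illustrative:

```python
import math
import random
import torch

def sample_block(grid_h, grid_w, scale_range, aspect_range):
    """Sample one rectangular block of patch indices on a grid_h x grid_w patch grid."""
    scale = random.uniform(*scale_range)
    aspect = random.uniform(*aspect_range)
    area = scale * grid_h * grid_w
    h = max(1, min(grid_h, round(math.sqrt(area / aspect))))
    w = max(1, min(grid_w, round(math.sqrt(area * aspect))))
    top = random.randint(0, grid_h - h)
    left = random.randint(0, grid_w - w)
    rows = torch.arange(top, top + h)
    cols = torch.arange(left, left + w)
    return (rows[:, None] * grid_w + cols[None, :]).flatten()

def sample_context_and_targets(grid_h=14, grid_w=14, num_targets=4):
    # M = 4 target blocks, scale in (0.15, 0.2), aspect ratio in (0.75, 1.5)
    targets = [sample_block(grid_h, grid_w, (0.15, 0.2), (0.75, 1.5))
               for _ in range(num_targets)]
    # The context block is sampled independently (a large block; range illustrative)
    context = sample_block(grid_h, grid_w, (0.85, 1.0), (1.0, 1.0))
    # Remove any patches that overlap with the targets, to avoid trivial prediction
    target_set = torch.cat(targets)
    keep = ~torch.isin(context, target_set)
    return context[keep], targets
```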

Ultimately, to train the network, an L2 loss (mean squared error) is applied between the predicted and the ground-truth target representations, averaged over the M target blocks:

Mean Squared Error Loss
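Written out explicitly in the paper's notation (as I read it), with $\hat{s}_y(i)$ the predicted and $s_y(i)$ the ground-truth patch-level representations of the $i$-th target block $B_i$, the loss is the average L2 distance between them:

$$
\frac{1}{M}\sum_{i=1}^{M} D\!\left(\hat{s}_y(i),\, s_y(i)\right)
\;=\;
\frac{1}{M}\sum_{i=1}^{M}\;\sum_{j \in B_i} \left\lVert \hat{s}_{y_j} - s_{y_j} \right\rVert_2^2
$$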

Finally, the parameters of the context encoder and the predictor network are updated through gradient-based optimisation, while the target encoder's parameters are updated as an exponential moving average of the context-encoder parameters, which the authors note is essential for training JEAs with ViTs.
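A minimal sketch of that exponential-moving-average update (the momentum value is illustrative; the paper schedules it towards 1.0 over training):

```python
import torch

@torch.no_grad()
def ema_update(target_encoder, context_encoder, momentum=0.996):
    """Move target-encoder weights towards the context encoder; no gradients flow here."""
    for t_param, c_param in zip(target_encoder.parameters(),
                                context_encoder.parameters()):
        t_param.mul_(momentum).add_(c_param, alpha=1 - momentum)
```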

Visualisations, Performance and Conclusion

Fig 6: The first column shows the original input image; the remaining columns are generative-model decodings of the average-pooled outputs of the Image-based Joint-Embedding Predictive Architecture

I-JEPA requires less compute than pixel-reconstruction methods (generative architectures) and than iBOT (hand-crafted augmentations) because it converges faster, which in turn allows the data to be scaled up for better transfer-learning performance thanks to greater diversity.

I-JEPA outperforms generative methods and performs on par with invariance-based methods on image-classification tasks.
I-JEPA outperforms invariance-based methods on low-level tasks and performs nearly as well as MAE (a generative method) on the same tasks.

This article aimed to dissect the Image-based Joint-Embedding Predictive Architecture for self-supervised learning in a shorter and slightly simpler form. I personally found the paper quite interesting; although generative architectures are progressing quickly in this domain, this architecture does help fill some of the remaining gaps.

All the diagrams, barring Fig 3 and Fig 4, have been taken directly from the paper. Fig 4 was made by me to give a simpler view of how the architecture functions.

Fig 3 is sourced from here — link

Thank you :)
