Making a VICReg

Background

In Lex Fridman’s interview with Yann LeCun, LeCun says there are a few things that LLMs in their current state fundamentally can’t do but a human can.

Some of these include:

  • Planning
  • Logical reasoning

He is an advocate of self-supervised learning, and in the interview he talks about JEPA (Joint Embedding Predictive Architecture) and how self-supervised learning could allow models to construct a world model just by observing the world.

alt text

The joint embedding cuts out the irrelevant details of a feed (like cutting out specific leaf positions in a dashcam video) and keeps the important information (like which way the road is turning in a dashcam video).

Once this embedding is trained, a predictor can be trained on top of it to predict the next frame’s embedding from the previous frame’s embedding.

This is essentially trying to predict the next frame in a given video, but cutting out irrelevant details by doing the predictions in embedding space.

alt text

Yann LeCun mentions VICReg as an example embedding network, so I looked up the paper and tried it out on CIFAR-10.

Results

The final network has the following architecture:

alt text

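import torch
import torch.nn as nn
from torchvision import models


class VICReg(nn.Module):
    def __init__(self, embedding_dimension, representation_dimension):
        super().__init__()
        # VGG16 convolutional backbone, frozen. Note: models.vgg16() with no
        # weights argument is randomly initialized, not pretrained.
        self.features = models.vgg16().features
        for param in self.features.parameters():
            param.requires_grad = False

        # Encoder head: 512 backbone features -> embedding.
        # A 32x32 CIFAR-10 image leaves a 512x1x1 feature map after
        # VGG16's five max-pooling layers, hence the 512 inputs.
        self.embedding = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Linear(32, embedding_dimension),
        )

        # Expander: maps the embedding to the larger space returned by forward().
        self.expand = nn.Sequential(
            nn.Linear(embedding_dimension, 80),
            nn.BatchNorm1d(80),
            nn.ReLU(),
            nn.Linear(80, 80),
            nn.BatchNorm1d(80),
            nn.ReLU(),
            nn.Linear(80, representation_dimension),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # (N, 512, 1, 1) -> (N, 512)
        x = self.embedding(x)
        x = self.expand(x)
        return x


# 50-dimensional embedding, 100-dimensional expanded representation
model = VICReg(50, 100)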

If you want to try this yourself, I used this tutorial.

The coolest thing about VICReg is the loss function. With just an MSE loss, all embeddings would collapse to a single constant point (you get 0 loss if you map everything to a single point).

I omitted the pushing forces in my loss function and got a collapse of the 2-dimensional embedding space like this:

alt text alt text
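
The collapse can also be reproduced on paper. Below is a minimal plain-Python sketch (no PyTorch) under made-up inputs; `mse` and `constant_encoder` are hypothetical helpers for illustration, not part of the code used in this post. An encoder that ignores its input scores a perfect MSE on any pair of views, while telling nothing apart.

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def constant_encoder(image):
    """A collapsed 'encoder': ignores the input and returns a fixed point."""
    return [0.5, 0.5]


# Two made-up "views" of the same image, plus an unrelated image.
view_1 = [0.1, 0.9, 0.3]
view_2 = [0.2, 0.8, 0.4]
other_image = [0.7, 0.0, 0.6]

pair_loss = mse(constant_encoder(view_1), constant_encoder(view_2))
print(pair_loss)  # 0.0: the pulling term is perfectly satisfied

# The embedding carries no information: unrelated images land on the same point.
print(constant_encoder(view_1) == constant_encoder(other_image))  # True
```

This is why some extra term is needed to stop the network from taking the trivial solution.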

There are contrastive and non-contrastive methods for preventing collapse.

Instead of using a contrastive method (positive examples to pull together and negative examples to push apart), VICReg uses a loss function with regularization terms, so only positive examples are needed.

The loss function is made of 3 parts…

V loss

The V stands for variance. This looks at the variance of each embedding dimension over the batch. The loss is a hinge on the per-dimension standard deviation: it falls as the standard deviation rises toward a target of 1, and is zero beyond that. This pushes representations apart and prevents collapse.

The formula looks like this:

alt text

For example, with a batch size of 2, if the two samples had the identical embeddings [1,2,3] and [1,2,3], the variance loss would be high. If they had the embeddings [0,0,0] and [1,1,1], the variance loss would be lower.

alt text
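
The two batches from the example can be checked by hand. This is a plain-Python sketch of the same hinge (no PyTorch), assuming a target standard deviation of 1, an epsilon of 0.0001, and the sample variance with n - 1 in the denominator; `variance_loss` is a hypothetical helper written for this illustration.

```python
import math
import statistics

EPSILON = 0.0001  # small constant inside the square root
TARGET_STD = 1.0  # hinge threshold: no penalty once std reaches this value


def variance_loss(batch):
    """Hinge on the per-dimension standard deviation, averaged over dimensions."""
    dims = list(zip(*batch))  # transpose: one tuple of values per dimension
    hinge_terms = []
    for values in dims:
        # statistics.variance is the sample variance (divides by n - 1)
        std = math.sqrt(statistics.variance(values) + EPSILON)
        hinge_terms.append(max(0.0, TARGET_STD - std))
    return sum(hinge_terms) / len(hinge_terms)


collapsed = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]  # both samples identical
spread = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]     # samples differ in every dimension

print(round(variance_loss(collapsed), 4))  # 0.99
print(round(variance_loss(spread), 4))     # 0.2928
```

The identical batch sits almost at the maximum penalty of 1, while the spread batch is penalized much less and would reach zero once every dimension had a standard deviation of at least 1.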

I loss

The I stands for invariance. This is the normal MSE loss, which pulls the representations of similar samples together.

alt text

C loss

The C stands for covariance. This is not necessary but nice to have. It prevents dimensional collapse: if we allow a 3-dimensional embedding space, it encourages the model to use all 3 dimensions for the embedding and discourages correlations between dimensions.

alt text
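
A small worked case shows what this term penalizes. The sketch below is plain Python (no PyTorch), assuming the loss is the sum of squared off-diagonal covariance entries divided by the number of dimensions; `covariance_loss` and both toy batches are made up for illustration.

```python
def covariance_loss(batch):
    """Sum of squared off-diagonal covariance entries, divided by the dimension."""
    n, d = len(batch), len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in batch]
    total = 0.0
    for i in range(d):
        for j in range(d):
            if i != j:  # the diagonal (per-dimension variance) is skipped
                cov_ij = sum(row[i] * row[j] for row in centered) / (n - 1)
                total += cov_ij ** 2
    return total / d


# Dimension 2 is a copy of dimension 1: only one direction is really used.
correlated = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
# The two dimensions are uncorrelated.
uncorrelated = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]

print(covariance_loss(correlated))    # 1.0
print(covariance_loss(uncorrelated))  # 0.0
```

Points that all lie along a diagonal line are penalized, while points that use both axes independently are not.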

VICReg loss

Combining it all, we get a pulling force from the I loss and pushing forces from the V loss and the C loss.

alt text
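
For reference, the combined objective can be written out as below. This follows the notation of the VICReg paper: $Z$ and $Z'$ are the two batches of $n$ representations of dimension $d$, $z_i$ is the $i$-th sample, $z^{j}$ is the $j$-th dimension across the batch, and $\bar{z}$ is the batch mean. Matching the symbols to the code in this post is an assumption: $\lambda$, $\mu$, $\nu$ would correspond to `I_importance`, `V_importance`, `C_importance`, with $\gamma = 1$ and $\epsilon$ equal to `epsilon`.

```latex
\begin{aligned}
\ell(Z, Z') &= \lambda\, s(Z, Z') + \mu\,[v(Z) + v(Z')] + \nu\,[c(Z) + c(Z')] \\
s(Z, Z') &= \frac{1}{n} \sum_{i=1}^{n} \lVert z_i - z'_i \rVert_2^2 \\
v(Z) &= \frac{1}{d} \sum_{j=1}^{d} \max\!\left(0,\ \gamma - \sqrt{\operatorname{Var}(z^{j}) + \epsilon}\right) \\
C(Z) &= \frac{1}{n-1} \sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^{\top} \\
c(Z) &= \frac{1}{d} \sum_{i \neq j} [C(Z)]_{i,j}^{2}
\end{aligned}
```

One detail to keep in mind: `torch.nn.MSELoss()` with its default `reduction='mean'` averages over every element, so it equals $s(Z, Z') / d$ rather than $s(Z, Z')$ itself. This only rescales the invariance weight.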

This is the code I used:

# Loss weights: invariance, variance, covariance
I_importance = 25
V_importance = 25
C_importance = 1

epsilon = 0.0001  # keeps the square root numerically stable
relu = nn.ReLU()
mse = nn.MSELoss()


def loss_var(z):
    # Hinge on the per-dimension standard deviation over the batch:
    # zero once every dimension has a standard deviation of at least 1.
    std_z = torch.sqrt(z.var(dim=0) + epsilon)
    return torch.mean(relu(1 - std_z))


def loss_cov(z):
    # Sum of squared off-diagonal covariance entries, divided by D.
    N, D = z.shape
    z = z - z.mean(dim=0)
    cov_z = ((z.T @ z) / (N - 1)).square()  # DxD, entries already squared
    return (cov_z.sum() - cov_z.diagonal().sum()) / D


def VICReg_loss(z_a, z_b):
    # z_a, z_b: (N, D) model outputs for the two inputs being compared
    inv_loss = mse(z_a, z_b)
    var_loss = loss_var(z_a) + loss_var(z_b)
    cov_loss = loss_cov(z_a) + loss_cov(z_b)
    return I_importance * inv_loss + V_importance * var_loss + C_importance * cov_loss


The embedding space I got looks a bit funky, but that is probably because I am visualizing a 50-dimensional embedding space in 2 dimensions with t-SNE.

alt text alt text

Loss averaged over 200 batches each:

alt text

The diagonal pattern is probably because the covariance term contributes too little to the loss, which lets points lie along the diagonal instead of using the whole space.

I tried making the weights 1, 1, 1 and increased the embedding dimension to 50, and the points spread out a bit.

alt text