Captivating Captions: Generating Instagram-Worthy Captions with Fine-Tuned Microsoft-git-base

Mohammed Ashraf
4 min read · Jan 29, 2024



In the realm of Instagram, where every picture tells a story, the caption is the narrative that breathes life into your visuals. Our fine-tuned GIT model takes the essence of your images and weaves captivating, Insta-worthy captions that resonate with your audience. From sparking curiosity to evoking emotions, each caption is a masterpiece tailored to your unique content.

Introduced in “GIT: A Generative Image-to-text Transformer for Vision and Language” by Wang et al., this transformative model redefines the synergy between visual and textual elements.

In my earlier exploration of Instagram caption generation, I experimented with building on BLIP and the ChatGPT API. Combining these tools yielded impressive results, efficiently crafting captivating captions for visual content. However, the journey of innovation is relentless. Now, let’s fine-tune the GIT model for even more personalized Instagram captions.

The dataset I used: mrSoul7766/instagram_post_captions on the Hugging Face Hub.

Step 1: Setup and Install Necessary Libraries

This step installs the necessary libraries for working with Hugging Face models and datasets.

! pip install git+https://github.com/huggingface/transformers.git accelerate datasets

Step 2: Load the Instagram Dataset

This step loads a pre-existing Instagram post captions dataset from the Hugging Face Hub, taking only a slice of the training split.

from datasets import load_dataset

dataset = load_dataset("mrSoul7766/instagram_post_captions", split='train[0:1000]')

In my case, I took a sample of 1000 images, which my free T4 machine can handle.
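As a quick, optional check, you can inspect a single record. Each item should contain an image and its caption, which are the field names used in the preprocessing step below:

example = dataset[0]
print(example.keys())       # expected: dict_keys(['image', 'caption'])
print(example["caption"])   # raw Instagram caption text
example["image"]            # PIL image (renders inline in a notebook cell)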

Step 3: Preprocess the Data

  • Loading the processor associated with microsoft/git-base (it bundles the image processor and the tokenizer).
  • Creating a custom ImageCaptioningDataset class that inherits from torch.utils.data.Dataset.
  • Defining the __getitem__() method to load and preprocess a single data item, including resizing the image, tokenizing the caption, and padding the sequences to a fixed length.
  • Creating a train_dataset object using the ImageCaptioningDataset class.
import torch
from torch.utils.data import Dataset

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/git-base")

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # preprocess the image and tokenize the caption in a single call
        encoding = self.processor(images=item["image"], text=item["caption"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}

        return encoding

train_dataset = ImageCaptioningDataset(dataset, processor)

Step 4: Create the DataLoader

Create a DataLoader object for the training dataset. The DataLoader helps in iterating through the dataset in batches during training.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)
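As an optional sanity check, pulling one batch from the DataLoader and printing the tensor shapes helps confirm the preprocessing worked; the exact sequence length depends on the processor’s max_length setting:

batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)
# typical keys: input_ids, attention_mask, pixel_values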

Step 5: Load the Pre-trained Git-Base Model

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

Step 6: Training the Model

  • Setting the model to training mode.
  • Iterating through the training dataset in batches.
  • For each batch, pass the input image pixel values and caption token IDs to the model.
  • Calculating the loss, which is the cross-entropy loss between the predicted and actual caption token IDs.
  • Backpropagating the loss to calculate gradients.
  • Updating the model’s parameters using an optimizer.
  • Printing the loss after each batch to monitor the training progress.
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids)

        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

Step 7: Image Captioning Inference

  • Preprocessing the image using the processor.
  • Passing the preprocessed pixel values to the model to generate a sequence of token IDs.
  • Decoding the generated token IDs back into a human-readable caption using the processor.
  • Printing the generated caption.
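Here is a minimal sketch of these inference steps, assuming the fine-tuned model and processor from the previous steps are still in memory; the image path test_post.jpg is just a placeholder for whatever image you want to caption:

from PIL import Image

# load the image to caption (placeholder path)
test_image = Image.open("test_post.jpg")

# preprocess the image and move the pixel values to the same device as the model
pixel_values = processor(images=test_image, return_tensors="pt").pixel_values.to(device)

model.eval()
with torch.no_grad():
    # generate a sequence of caption token IDs from the image
    generated_ids = model.generate(pixel_values=pixel_values, max_length=64)

# decode the token IDs back into a human-readable caption
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(caption)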

Congratulations! You have successfully fine-tuned git-base for Instagram caption generation.

Conclusion

In conclusion, the successful fine-tuning of the git-base model marks a significant milestone in the quest for crafting Instagram-style captions on images. Through experimentation and dedication, we have harnessed the potential of the git-base model, elevating the art of caption generation to new heights.
