Captivating Captions: Generating Instagram-Worthy Captions with Fine-Tuned Microsoft-git-base
In the realm of Instagram, where every picture tells a story, the caption is the narrative that breathes life into your visuals. Our fine-tuned GIT model takes your images and generates catchy, Insta-worthy captions meant to resonate with your audience, from sparking curiosity to evoking emotion, each one tailored to the content of the image.
Introduced in ‘GIT: A Generative Image-to-text Transformer for Vision and Language’ by Wang et al., GIT simplifies vision-language generation to one image encoder and one text decoder trained with a single language-modeling objective.
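To make "image encoder plus text decoder" concrete: the GIT paper describes feeding the decoder a sequence of image tokens followed by text tokens, with an attention mask under which image tokens attend only to each other, while each text token attends to all image tokens and to the text tokens before it. The pure-Python sketch below builds that mask for hypothetical sizes; it illustrates the idea and is not code taken from the paper or from the Hugging Face implementation.

```python
def git_style_attention_mask(num_image_tokens, num_text_tokens):
    """Return an n x n mask (1 = row token may attend to column token).

    Image tokens see every image token but no text token; each text token sees
    all image tokens plus itself and earlier text tokens (causal).
    """
    n = num_image_tokens + num_text_tokens
    mask = []
    for i in range(n):
        row = []
        for j in range(n):
            if i < num_image_tokens:
                # Image rows: bidirectional attention among image tokens only
                row.append(1 if j < num_image_tokens else 0)
            else:
                # Text rows: all image tokens plus preceding text tokens
                row.append(1 if j < num_image_tokens or j <= i else 0)
        mask.append(row)
    return mask

# Hypothetical sizes: 2 image tokens followed by 3 caption tokens
for row in git_style_attention_mask(2, 3):
    print(row)
```

The causal part of the mask is what lets the same decoder generate a caption one token at a time at inference.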
In my earlier exploration of Instagram caption generation, I built a pipeline that combined BLIP with the ChatGPT API. That combination produced impressive captions for visual content. However, the journey of innovation is relentless: this time, let’s fine-tune the GIT model itself for more personalized Instagram captions.
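The Dataset I Used
The examples come from the mrSoul7766/instagram_post_captions dataset on the Hugging Face Hub, where each row pairs an Instagram image (the image field) with its caption (the caption field).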
Step 1: Set Up and Install the Necessary Libraries
This step installs the libraries needed for working with Hugging Face models and datasets: transformers (installed from its GitHub source), accelerate, and datasets.
! pip install git+https://github.com/huggingface/transformers.git accelerate datasets
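An optional sanity check that is not part of the original walkthrough: the standard-library importlib.metadata module can report which versions of the three packages were actually installed, returning None for any that are missing. A minimal sketch:

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages=("transformers", "accelerate", "datasets")):
    """Map each package name to its installed version string, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result

print(installed_versions())
```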
Step 2: Load the Instagram Dataset
This step loads the Instagram post captions dataset from the Hugging Face Hub. The split argument train[0:1000] keeps only the first 1,000 examples of the train split; the code does not create a separate test set.
from datasets import load_dataset
dataset = load_dataset("mrSoul7766/instagram_post_captions", split='train[0:1000]')
In my case, I used only the first 1,000 images, which is what my free T4 GPU could handle.
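If you want held-out examples to compare generated captions against, one option is a seeded shuffle-and-split of the sample. The datasets library offers this directly via Dataset.train_test_split(test_size=..., seed=...); the pure-Python sketch below shows the same idea on indices, with an assumed 10% hold-out. The resulting index lists could then be passed to dataset.select(...) to build the two subsets.

```python
import random

def split_indices(n_examples, test_fraction=0.1, seed=42):
    """Shuffle example indices with a fixed seed and split them into (train, test) lists."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    n_test = int(round(n_examples * test_fraction))
    return indices[n_test:], indices[:n_test]

# Usage with the 1,000-example sample and an assumed 10% hold-out
train_idx, test_idx = split_indices(1000, test_fraction=0.1)
print(len(train_idx), len(test_idx))  # 900 100
```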
Step 3: Preprocess the Data
- Loading the processor for microsoft/git-base, which bundles the model's image processor and tokenizer.
- Creating a custom ImageCaptioningDataset class that inherits from torch.utils.data.Dataset.
- Defining the __getitem__() method to load and preprocess a single data item: the processor resizes and normalizes the image, tokenizes the caption, and pads the token sequence to a fixed length.
- Creating a train_dataset object from the ImageCaptioningDataset class.
import torch
from torch.utils.data import Dataset
from transformers import AutoProcessor

# The processor bundles git-base's image processor and tokenizer
processor = AutoProcessor.from_pretrained("microsoft/git-base")

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Preprocess the image and tokenize + pad the caption in one call
        encoding = self.processor(images=item["image"], text=item["caption"],
                                  padding="max_length", return_tensors="pt")
        # Remove the batch dimension added by return_tensors="pt"
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding

train_dataset = ImageCaptioningDataset(dataset, processor)
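What padding="max_length" produces can be sketched in plain Python: the token ids are extended with a pad id up to a fixed length, and an attention mask marks real tokens with 1 and padding with 0. Because the training labels are built from these ids, padded positions can be excluded from the loss by labelling them -100, the target value PyTorch's CrossEntropyLoss ignores by default. In this sketch the pad id 0 is an assumption (BERT-style tokenizer), the token ids are hypothetical, and the truncation mirrors what the processor does only when truncation=True is passed.

```python
PAD_TOKEN_ID = 0     # assumption: BERT-style [PAD] id
IGNORE_INDEX = -100  # PyTorch CrossEntropyLoss skips this target by default

def pad_to_max_length(token_ids, max_length, pad_id=PAD_TOKEN_ID):
    """Truncate/pad token ids to max_length and build the matching attention mask."""
    ids = list(token_ids)[:max_length]  # sketch truncates, like truncation=True
    n_pad = max_length - len(ids)
    input_ids = ids + [pad_id] * n_pad
    attention_mask = [1] * len(ids) + [0] * n_pad
    return input_ids, attention_mask

def make_labels(input_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Copy input_ids as labels, replacing padded positions with ignore_index."""
    return [tok if keep else ignore_index for tok, keep in zip(input_ids, attention_mask)]

# Hypothetical token ids for a short caption
ids, mask = pad_to_max_length([101, 2023, 2003, 102], max_length=8)
print(ids)                     # [101, 2023, 2003, 102, 0, 0, 0, 0]
print(mask)                    # [1, 1, 1, 1, 0, 0, 0, 0]
print(make_labels(ids, mask))  # [101, 2023, 2003, 102, -100, -100, -100, -100]
```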
Step 4: Create the DataLoader
Create a DataLoader object for the training dataset. The DataLoader helps in iterating through the dataset in batches during training.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)
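Roughly, shuffle=True with batch_size=2 means: permute the example indices, then walk through them two at a time. The sketch below mimics that in plain Python (it is not DataLoader's actual implementation) and works out the step count: 1,000 examples in batches of 2 give 500 batches per epoch, so the 50-epoch loop in Step 6 performs 25,000 optimizer steps. Unlike this sketch's fixed seed, DataLoader draws a fresh permutation each epoch.

```python
import random

def iterate_batches(n_items, batch_size, shuffle=True, seed=None):
    """Yield lists of example indices, mimicking DataLoader(shuffle=..., batch_size=...)."""
    indices = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, n_items, batch_size):
        # The last batch may be smaller, as with DataLoader's default drop_last=False
        yield indices[start:start + batch_size]

batches_per_epoch = sum(1 for _ in iterate_batches(1000, batch_size=2, seed=0))
print(batches_per_epoch)       # 500
print(batches_per_epoch * 50)  # 25000 optimizer steps over 50 epochs
```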
Step 5: Load the Pre-trained Git-Base Model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
Step 6: Training the Model
- Moving the model to the GPU if one is available and setting it to training mode.
- Iterating through the training dataset in batches.
- Passing each batch's image pixel values and caption token IDs to the model.
- Calculating the loss, which is the cross-entropy between the predicted next tokens and the actual caption token IDs (the model shifts the labels internally).
- Backpropagating the loss to calculate gradients.
- Updating the model’s parameters with the AdamW optimizer.
- Printing the loss after each batch to monitor training progress.
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        # Label padded positions with -100 so the cross-entropy loss ignores them
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        labels=labels)

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
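The loss that outputs.loss reports can be illustrated without PyTorch. As far as the Hugging Face GIT implementation goes, the model shifts the labels itself, so the scores at text position t are compared with the token at position t + 1, and targets equal to -100 are skipped. The pure-Python sketch below computes that mean next-token cross-entropy on toy scores; it is an illustration of the technique, not the library's code.

```python
import math

IGNORE_INDEX = -100

def shifted_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy: scores at position t are scored against labels[t + 1].

    logits: one list of raw vocabulary scores per text position.
    labels: one target token id per text position (ignore_index entries are skipped).
    """
    total, count = 0.0, 0
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue
        # Numerically stable log-sum-exp for the softmax normaliser
        top = max(scores)
        log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
        total += log_norm - scores[target]  # -log softmax(scores)[target]
        count += 1
    return total / count if count else 0.0

# Uniform scores over a 4-token vocabulary give a loss of log(4), about 1.386
print(shifted_cross_entropy([[0.0] * 4] * 3, [1, 2, 3]))
```

A loss near log(vocabulary size) means the model is guessing uniformly; it should fall well below that as training progresses.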
Step 7: Image Captioning Inference
- Preprocessing the image with the processor.
- Passing the preprocessed pixel values to the model to generate a sequence of token IDs.
- Decoding the generated token IDs back into a human-readable caption with the processor's tokenizer.
- Printing the generated caption.
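The inference steps above can be sketched as a small helper, assuming the fine-tuned model and the processor from the earlier steps are still in memory. It uses the standard Hugging Face pattern for GIT: processor(images=..., return_tensors="pt").pixel_values, then model.generate(pixel_values=..., max_length=...), then processor.batch_decode(..., skip_special_tokens=True). The helper name generate_caption and the max_length of 50 tokens are assumptions for illustration.

```python
def generate_caption(model, processor, image, device="cpu", max_length=50):
    """Hypothetical helper: caption one image with a fine-tuned GIT model.

    max_length=50 is an assumed cap on the generated caption length (in tokens).
    """
    model.eval()  # switch off training-mode behaviour such as dropout
    # Resize/normalize the image and move the pixel tensor to the model's device
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    # generate() runs without gradient tracking and returns token IDs
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    # Decode the first (and only) sequence, dropping special tokens like [CLS]/[SEP]
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```

For example, print(generate_caption(model, processor, dataset[0]["image"], device=device)) would caption the first image of the loaded sample.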
Congratulations! You have fine-tuned git-base for Instagram caption generation.
Conclusion
In conclusion, fine-tuning the git-base model on Instagram image-caption pairs is a practical way to steer a general-purpose image captioner toward Instagram-style captions. With a processor, a DataLoader, and a plain PyTorch training loop, a free T4 GPU is enough to get started on a 1,000-image sample.