Vision-Language Models from Scratch: Introduction

Alex Moore
Apr 26, 2024

Part 0: Introduction

This project will track my from-scratch implementation of multimodal language models as a simple composition of existing models. As always, up-to-date code for this blog is on my GitHub.

Vision-language models (also called multimodal language models, MLLMs) are a compelling family of deep learning models. Multimodal language modeling, as a field of research, is interested in using large language model backbones as a starting point for introducing data of other kinds. Flamingo gives a motivating example of multimodality applied to visual question answering:

The Flamingo authors give an excellent example motivating the use of multimodal language models: generalist LLMs with promising zero-shot performance can be modified to also understand visual data.

In this example, the Flamingo authors motivate a use case of VLMs. In the first row, interwoven text and images are supplied as two examples followed by a query. The model relates the visual data to the caption (“This is a chinchilla…”), then follows the pattern to caption the uncaptioned flamingo image as its output.

In the fourth row, the authors supply a different task: given two examples of written arithmetic in image form, the model should read the input image, parse the visual symbols into math, and evaluate the arithmetic. Imagine training a separate model from scratch for each of these tasks! For the first row alone you would need to curate captioned images of animals and train a model to perform animal-image captioning. For the fourth row, you would need some dataset of arithmetic images to train and validate on. But a multimodal LLM provides a single “intelligent” backbone into which visual data can be injected.

Part 1: Motivation and Goals

The simple composition of a vision encoder and an LLM, joined by a lightweight projection model

Code implementation supplied here.

Perhaps the most exciting aspect of multimodal vision-language models is that the formulation is relatively simple. LLaVA is a popular cornerstone of open multimodal LLM research, and I highly recommend the paper in particular for its discussion of formulating datasets for the multimodal finetuning stage. The authors rely on a text-only GPT-4 for synthetic dataset generation, bootstrapping instruction data from image captions and relative object positions.
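To make that simplicity concrete, here is a minimal sketch of the lightweight projection module that sits between the vision encoder and the LLM. The class name, two-layer MLP shape, and hidden sizes (768 for a ViT-base encoder, 3072 for a Phi-3-mini-sized LLM) are illustrative assumptions, not the exact LLaVA design.

import torch.nn as nn

class VisionProjector(nn.Module):
    """Hypothetical lightweight projector: maps vision-encoder patch embeddings
    into the LLM's embedding space so they can be treated as 'image tokens'."""

    def __init__(self, vision_dim: int = 768, llm_dim: int = 3072):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, patch_embeddings):
        # patch_embeddings: [batch, num_patches, vision_dim]
        # returns:          [batch, num_patches, llm_dim]
        return self.net(patch_embeddings)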

Bunny is a similar library that emphasizes composing off-the-shelf components into a simple pipeline for introducing multimodality to LLMs.

CodeLlama is a code-only LLM, but it references fill-in-the-middle (prefix-suffix-middle) work that may be highly beneficial for bootstrapping data and generalizing the model to different text-image compositions.

One of my standing goals for a plausibility study would be to incorporate a domain-expert language model (think a chemistry-corpus-tuned GPT-3.5) and a domain encoder (think a molecular graph representation model) as a curiosity. Would this perform well? Is it useful to fuse an LLM that has chemistry knowledge with an encoder that comprehends chemical structures? Would zero-shot transfer be beneficial for general chemistry-understanding tasks? I use chemistry as an example, but imagine any data modality inside or outside the physical sciences: spectral data, graph data, social-network data, fluid-simulation data…

Part 2: Current System: Technology and Tools

Currently we rely on google/vit-base-patch16-224 as the vision backbone and Phi-3-mini-4k-instruct as the language backbone. Performance likely increases with larger LLMs, and specific tasks may benefit from specific image encoders (an encoder pretrained for object segmentation vs. one pretrained for image captioning, for example). We call the LLM's forward with embedding vectors (inputs_embeds) instead of token IDs so that we can supply a custom-format prompt.
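A minimal sketch of this inputs_embeds path, assuming the Hugging Face checkpoints named above. The multimodal_forward helper is illustrative, a plain Linear stands in for the trained projector, and depending on your transformers version Phi-3 may need trust_remote_code=True.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, ViTImageProcessor, ViTModel

# Assumed checkpoints, per the text above
vision_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
llm = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

# Simple linear stand-in for the lightweight projector
projector = torch.nn.Linear(vision_encoder.config.hidden_size, llm.config.hidden_size)

def multimodal_forward(image, prompt: str):
    # Encode the image to patch embeddings, then project into the LLM's space
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    patch_embeds = vision_encoder(pixel_values).last_hidden_state    # [1, patches, 768]
    image_tokens = projector(patch_embeds)                           # [1, patches, llm_dim]

    # Embed the text prompt with the LLM's own embedding table
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    text_embeds = llm.get_input_embeddings()(input_ids)              # [1, seq, llm_dim]

    # Concatenate along the sequence dimension and run the causal LM on embeddings
    inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)
    return llm(inputs_embeds=inputs_embeds).logits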

I have also added LoRA adapters to the LLM and the ViT; the LLM LoRA uses a hidden dimension (rank) of 64 and the ViT LoRA uses 8.
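A sketch of how such adapters could be attached with the peft library. The ranks match the numbers above, but the lora_alpha values and the target module names ("qkv_proj"/"o_proj" for Phi-3, "query"/"key"/"value" for the ViT) are assumptions.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, ViTModel

llm = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
vision_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")

# Rank-64 adapters on the LLM attention projections (module names vary by architecture)
llm_lora = LoraConfig(r=64, lora_alpha=128, target_modules=["qkv_proj", "o_proj"])
llm = get_peft_model(llm, llm_lora)

# Rank-8 adapters on the ViT attention projections
vit_lora = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "key", "value"])
vision_encoder = get_peft_model(vision_encoder, vit_lora)

llm.print_trainable_parameters()
vision_encoder.print_trainable_parameters()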

The current data is COCO-2017 train/val. We use custom formatting that converts each image-caption pair into an image-question-answer context for end-to-end finetuning.
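A minimal sketch of that reformatting; the coco_to_vqa helper and the question templates are my own assumptions, and the repo's exact prompts may differ.

import random

# Hypothetical caption-to-QA reformatting for COCO image-caption pairs
QUESTION_TEMPLATES = [
    "What is shown in this image?",
    "Describe the image.",
    "Write a short caption for this picture.",
]

def coco_to_vqa(sample):
    """Turn a COCO {'image': ..., 'caption': ...} pair into an image-question-answer example."""
    return {
        "image": sample["image"],
        "question": random.choice(QUESTION_TEMPLATES),
        "answer": sample["caption"],
    }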

Training and validation use a Lightning LightningModule wrapper around the VLM (the VLM holds .image_encoder, .image_processor, .llm, .llm_processor, and .projector).
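A sketch of what that wrapper might look like; the class name LitVLM, the optimizer choice, and the learning rate are assumptions, while the attribute layout follows the text above and the forward shown below (which returns logits and a loss).

import lightning as L
import torch

class LitVLM(L.LightningModule):
    """Lightning wrapper around the composed VLM (sketch)."""

    def __init__(self, vlm, lr: float = 1e-4):
        super().__init__()
        self.vlm = vlm  # holds .image_encoder, .image_processor, .llm, .llm_processor, .projector
        self.lr = lr

    def training_step(self, batch, batch_idx):
        _, loss = self.vlm(batch)  # VLM.forward returns (logits, loss)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        _, loss = self.vlm(batch)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        # Only the projector and LoRA adapters are trainable in this setup
        trainable = [p for p in self.vlm.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.lr)

# Usage sketch:
# trainer = L.Trainer(max_epochs=1)
# trainer.fit(LitVLM(vlm), train_dataloader, val_dataloader)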

Summary of the existing algorithm (full code on GitHub) for a generic multimodal language model. This simple example shows multimodal data (an image) cast to a token sequence, formatted with a prompt, and causal language model training on that embedding sequence:

def train_step(image, query, caption):
    # Embed the image and project its patch embeddings into the LLM token space
    image_features = vlm.image_encoder(vlm.image_preprocessor(image))
    image_tokens = vlm.image_projection(image_features)

    # Embed the text query and caption with the LLM's embedding table
    query = vlm.llm.embed_tokens(vlm.llm_tokenizer(query))
    caption = vlm.llm.embed_tokens(vlm.llm_tokenizer(caption))

    # Format - concatenate along the sequence dimension
    inputs = torch.cat((image_tokens, query, caption), dim=1)

    # Forward - predict the next token from the combined embedding sequence
    output_logits = vlm.llm(inputs_embeds=inputs)
    return output_logits[:, -1].softmax(dim=-1)

def forward(self, batch):
    """
    Given a batch, format the input, run the forward pass, and return (logits, loss).
    Predicts the next token given an image and a random prefix of the caption.
    """
    device = batch['image'].device
    tokenized_image = self.image_forward(batch['image'])

    # Tokenize the caption string
    int_captions = torch.LongTensor(self.language_tokenizer(batch['caption'])['input_ids']).to(device)

    # Pick a random split point: tokens before it are context, the token at it is the target
    predict_at_index = random.randint(1, int_captions.shape[1] - 2)
    caption_prefix = self.embed_ints(int_captions[:, :predict_at_index])
    caption_target = int_captions[:, predict_at_index]

    self.start_vec = self.start_vec.to(device)
    self.end_vec = self.end_vec.to(device)
    self.query_vec = self.query_vec.to(device)

    # Structure the token sequence: <start> image tokens <end> query caption-prefix
    llm_input = torch.cat((self.start_vec, tokenized_image, self.end_vec, self.query_vec, caption_prefix), dim=1)

    # Forward with the frozen LLM, supplying embeddings directly
    output = self.language_model.forward(inputs_embeds=llm_input)
    logits = output.logits

    # Only the final position predicts the held-out caption token
    last_logit = logits[:, -1, :]
    loss = self.loss_function(last_logit, caption_target)

    return logits, loss

def generate(self, batch, max_new_tokens):
    device = batch['image'].device

    # Encode the image to projected "image tokens"
    tokenized_image = self.image_forward(batch['image'])

    # Tokenize the reference caption (used only to log it at the end;
    # generation starts from an empty caption prefix)
    int_captions = torch.LongTensor(self.language_tokenizer(batch['caption'])['input_ids']).to(device)

    predict_at_index = 0
    caption_prefix = self.embed_ints(int_captions[:, :predict_at_index])

    self.start_vec = self.start_vec.to(device)
    self.end_vec = self.end_vec.to(device)
    self.query_vec = self.query_vec.to(device)

    # Structure the token sequence: <start> image tokens <end> query
    llm_input = torch.cat((self.start_vec, tokenized_image, self.end_vec, self.query_vec, caption_prefix), dim=1)

    logit_outputs = []
    for _ in range(max_new_tokens):
        # Forward on the current prompt
        outputs = self.language_model.forward(inputs_embeds=llm_input)
        logit_output = outputs.logits[:, -1, :]
        logit_outputs.append(logit_output)

        # Greedily pick the next token and append its EMBEDDED vector to the sequence
        int_output = logit_output.argmax(dim=1)
        new_vec = self.embed_ints(int_output).unsqueeze(1)  # [B, 1, D]
        llm_input = torch.cat((llm_input, new_vec), dim=1)

    logit_outputs = torch.stack(logit_outputs, dim=1)

    # Logits to token ids, then decode to a string
    int_outputs = logit_outputs.argmax(dim=2)
    str_outputs = self.language_tokenizer.decode(int_outputs[0])
    print('Caption: [', batch['caption'], ']')
    print('Str out: [', str_outputs, ']')
    return str_outputs
