Creating a Text Prediction Dataloader
While working on a project, I needed to train a transformer on an unsupervised task: text prediction. For this task I needed to break each input sequence into a source sequence and a target sequence, with the target being the same sequence shifted by one token, so that at every position the model learns to predict the next token.
Example sequence: The cat likes to sit in the sun on warm days
Source: The cat likes to sit in the sun on warm
Target: cat likes to sit in the sun on warm days
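To make the shift concrete, here is a minimal sketch of the split in plain Python. The function name make_source_target is hypothetical, and the whitespace tokenization is an assumption made only for illustration; a real pipeline would likely use a proper tokenizer and integer token IDs.

```python
def make_source_target(tokens):
    """Split a token sequence into (source, target), with target shifted by one."""
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form a source/target pair")
    source = tokens[:-1]  # every token except the last
    target = tokens[1:]   # every token except the first
    return source, target


sentence = "The cat likes to sit in the sun on warm days"
tokens = sentence.split()  # whitespace tokenization, for illustration only
source, target = make_source_target(tokens)

print("Source:", " ".join(source))
print("Target:", " ".join(target))
```

Both sequences end up one token shorter than the input, and target[i] is always the token that follows source[i].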
I used PyTorch Lightning, which required me to use a dataloader. I spent quite a bit of time looking for an example of how to load text into a dataloader to get batched source and target sequences. Perhaps my googling skills need work, but I couldn’t find anything, so I built a dataloader that does this! Hopefully it is useful for you.
A notebook with this code and an example is available here.