WTF is going on with fast.ai?
This article would probably be better titled “Deconstructing fast.ai” or “Fast.ai from first principles”, and will center on building a model for text generation with PyTorch. I apologize for the rather click-bait headline.
I’m assuming readers are familiar with fast.ai and have probably taken its Deep Learning online course. I did the first iteration of the course, when it used Keras. The second iteration, released to the public only last week, switched to using PyTorch as a framework. I took that version of the course in person as the University of San Francisco’s Certificate in Deep Learning (whose lessons are recorded and compiled into the MOOC).
The first course, using Keras, was awesome: super intuitive, practical, and inspiring. Version 2, using PyTorch, was more cutting edge and contained much of the same theoretical material as before, except that it tried to abstract away some of PyTorch’s pain points, and its flexibility, with Prof Jeremy Howard’s own fast.ai library. For me, it took away much of the visibility and immediacy of PyTorch and instead provided a variety of magical functions.
That’s in line with the ethos of the course, which is about building performant AI applications quickly, but it did not sit well with me. The fast.ai library does not have the elegance of Keras; it is poorly documented; the code tries too hard to be clever rather than readable; and besides, PyTorch doesn’t really need an abstraction layer. BLAS, wrapped with CUDA, wrapped with PyTorch, is in my opinion a great and adequate stack for building AI. If I abstract that any further, I might as well use an AWS machine learning API.
With that goal in mind, let’s take a look at cutting out the fast.ai cruft from their recurrent neural network example. You can find their original notebook on their GitHub, here, and my simplified version in the repo here.
fast.ai’s notebook generates some Nietzsche-like text, and it begins by getting some input data — Mr Existential Angst’s complete works — with:
Wrapping a basic urllib method seems rather unnecessary to me. They do it primarily to put a tqdm progress bar on the download, but I think that’s just arcana that obscures the simple purpose of the code. Instead, let’s KISS, and keep things light with a bit of Shakespeare.
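The replacement download code isn’t reproduced here. Below is a minimal sketch of the KISS approach. It uses only the standard library’s urllib.request.urlretrieve, with no progress bar. The helper name fetch_text and its skip-if-already-downloaded behaviour are my own assumptions.

```python
import os
import urllib.request


def fetch_text(url, path):
    """Download url to path (only if path doesn't exist yet) and return its text.

    Hypothetical helper: a plain urllib call, no progress bar or wrapper classes.
    """
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    with open(path, encoding="utf-8") as f:
        return f.read()
```

For example, text = fetch_text(url, "shakespeare.txt"), with url pointing at any plain-text corpus. One convenient source is the Tiny Shakespeare file in Andrej Karpathy’s char-rnn repository (https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). That choice is an assumption on my part and not necessarily the file used here.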
The rest of fast.ai’s pre-processing just converts that text to numbers that can be input to a neural network (and provides for the reverse transformation). It’s not particularly interesting, and looks like this:
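That code isn’t reproduced here. Below is a minimal sketch of the same idea: sort the distinct characters, give each one an integer index, and keep a reverse mapping so predictions can be turned back into text. The names build_vocab, encode and decode are hypothetical.

```python
def build_vocab(text):
    """Map each distinct character to an integer index, and back again."""
    chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for i, c in enumerate(chars)}
    return char_to_idx, idx_to_char


def encode(text, char_to_idx):
    """Text -> list of integer indices (the network's input)."""
    return [char_to_idx[c] for c in text]


def decode(indices, idx_to_char):
    """List of integer indices -> text (the reverse transformation)."""
    return "".join(idx_to_char[i] for i in indices)


char_to_idx, idx_to_char = build_vocab("hello")
print(encode("hello", char_to_idx))   # [1, 0, 2, 2, 3]
```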
Windowing The Text
The theory behind this network is that 3 preceding characters can predict the 4th character. fast.ai, quite sensibly, creates a training set of X and y as follows (slightly simplified by me):
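The original snippet isn’t reproduced here. The sketch below shows the same windowing idea under my own naming (make_windows and cs are assumptions). It walks through the encoded text three characters at a time. Each of the three input positions is collected into its own array, and the character that follows becomes the target.

```python
import numpy as np


def make_windows(idx, cs=3):
    """Turn a list of character indices into cs input arrays plus a target array.

    Windows start every cs characters; window i uses idx[i:i+cs] as inputs
    and idx[i+cs] as the character to predict.
    """
    starts = range(0, len(idx) - cs, cs)
    xs = [np.array([idx[i + k] for i in starts]) for k in range(cs)]  # one array per input position
    y = np.array([idx[i + cs] for i in starts])                       # the following character
    return xs, y


(x1, x2, x3), y = make_windows(list(range(10)))
print(x1, x2, x3, y)   # [0 3 6] [1 4 7] [2 5 8] [3 6 9]
```

Stepping by cs means the input windows don’t overlap, although each target is also the first input of the next window.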
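Here, they are using np.stack as a quick way to turn a list into a numpy array, but let’s briefly examine that function as it crops up later too. It joins a sequence of arrays along a new axis, chosen by the keyword argument axis (default 0). For a list of plain numbers, stacking along the first axis makes no difference compared with np.array. For a list of 1-D arrays, axis=0 makes each array a row, while stacking along the second axis (axis=1) makes each array a column: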
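A small illustration with two made-up arrays, a and b:

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

rows = np.stack([a, b])           # axis=0 (default): each input becomes a row
cols = np.stack([a, b], axis=1)   # axis=1: each input becomes a column

print(rows)   # [[1 2 3]
              #  [4 5 6]]
print(cols)   # [[1 4]
              #  [2 5]
              #  [3 6]]

# For a list of scalars, axis=0 stacking is the same as np.array.
print(np.stack([1, 2, 3]))   # [1 2 3]
```

So stacking one array per input character with axis=1 gives one row per training example and one column per input position.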
Building The Model
The actual models in fast.ai are really well explained and are not the focus of this article, but here’s their simple code for taking those 3 input characters and predicting the next. One simplification I made: fast.ai’s code uses a function called V as shorthand for torch.autograd.Variable. That’s useful if you’re writing a ton of models to teach people, but opaque for the learner, so I switched it back.
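The model code itself isn’t reproduced here. Below is a sketch of a model that takes those 3 input characters and predicts the next. It uses a shared embedding, a shared input layer, and a hidden layer reused at each step. This is my reconstruction: the class name Char3Model and the sizes n_fac=42 and n_hidden=256 are assumptions, not necessarily fast.ai’s values. Since PyTorch 0.4, Variable has been merged into Tensor, so plain tensors track gradients and no wrapper is needed at all.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Char3Model(nn.Module):
    """Predict the 4th character from the 3 preceding ones (a sketch)."""

    def __init__(self, vocab_size, n_fac=42, n_hidden=256):
        super().__init__()
        self.e = nn.Embedding(vocab_size, n_fac)       # char index -> dense vector
        self.l_in = nn.Linear(n_fac, n_hidden)         # input -> hidden, shared by all 3 chars
        self.l_hidden = nn.Linear(n_hidden, n_hidden)  # hidden -> hidden, reused at each step
        self.l_out = nn.Linear(n_hidden, vocab_size)   # hidden -> scores over the vocabulary

    def forward(self, c1, c2, c3):
        # Each c is an integer tensor of shape (batch,) holding character indices.
        in1 = F.relu(self.l_in(self.e(c1)))
        in2 = F.relu(self.l_in(self.e(c2)))
        in3 = F.relu(self.l_in(self.e(c3)))
        h = torch.zeros_like(in1)                      # initial hidden state, no Variable wrapper
        h = torch.tanh(self.l_hidden(h + in1))
        h = torch.tanh(self.l_hidden(h + in2))
        h = torch.tanh(self.l_hidden(h + in3))
        return F.log_softmax(self.l_out(h), dim=-1)    # log-probabilities per character
```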
Note in particular how the model takes three separate tensors as input, instead of a single input tensor of higher rank. Prof Jeremy’s original notebook does move to that as it progresses from the basic model above to one with an LSTM layer. But for now, splitting the inputs is easier to comprehend so let’s stick with it, even if it adds complexity to feeding training data to the model as we will see.
In other places the fast.ai code uses T as shorthand for the (wrapped) function torch.Tensor (in this case our tensor must index into the embeddings matrix with integers, so the particular tensor type here is LongTensor). That’s especially confusing when T is also numpy shorthand for transpose. I’ve reverted to the long form.
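A quick illustration of both meanings, with toy values. Integer indices become an int64 (Long) tensor before going into an embedding, while .T on a NumPy array means transpose. Here I use .long(); torch.tensor(..., dtype=torch.long) is the other modern spelling, and the legacy torch.LongTensor constructor still exists.

```python
import numpy as np
import torch
import torch.nn as nn

x = np.array([[0, 1, 2], [3, 4, 5]])   # toy character indices
print(x.T.shape)                       # (3, 2): for a NumPy array, .T is the transpose

idx = torch.from_numpy(x).long()       # int64, the dtype a LongTensor holds
print(idx.dtype)                       # torch.int64

emb = nn.Embedding(num_embeddings=6, embedding_dim=4)
vectors = emb(idx)                     # one 4-d vector looked up per index
print(vectors.shape)                   # torch.Size([2, 3, 4])
```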
Training The Model
This is my particular bugbear with the fast.ai code. Take a look at the original:
ColumnarModelData was glossed over on the course and indeed trying to dig into it with:
Plumbing the source code reveals that this object inherits from or uses the following:
BatchSampler, RandomSampler & SequentialSampler (from PyTorch itself, in torch.utils.data)
DataLoader is another undocumented class, while PassthruDataset simply says:
An abstract class representing a Dataset. All other datasets should subclass it.
Coincidentally, this is the same exact docstring as for the Dataset class. Well, we can at least see that whatever ColumnarModelData is, it returns an object that has a property which is an iterator. In fact, each iteration yields a batch of inputs, a simple idea when shown visually:
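For what it’s worth, the PyTorch building blocks named above are small and documented. The sketch below uses toy data; the arange tensors stand in for the real x1, x2, x3 and y. The samplers only produce lists of indices. DataLoader with shuffle=True pairs a RandomSampler and a BatchSampler with a dataset, so it yields aligned batches.

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                              SequentialSampler, TensorDataset)

# Samplers deal only in indices; they never touch the data itself.
seq_batches = list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
print(seq_batches)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# The same ten indices, shuffled, still grouped into batches of at most 3.
rand_batches = list(BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False))

# TensorDataset indexes several equal-length tensors in parallel, so every
# batch from the DataLoader holds aligned slices of x1, x2, x3 and y.
x1, x2, x3, y = (torch.arange(10) + k for k in range(4))   # toy stand-in data
loader = DataLoader(TensorDataset(x1, x2, x3, y), batch_size=4, shuffle=True)
batches = list(loader)
print([len(b[3]) for b in batches])   # [4, 4, 2]
```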
I replaced this with my own iterator:
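The iterator itself isn’t reproduced here. Below is a minimal sketch of one, assuming x1, x2, x3 and y are equal-length NumPy integer arrays. The name batch_iter and its parameters are hypothetical.

```python
import numpy as np
import torch


def batch_iter(x1, x2, x3, y, batch_size=512, shuffle=True):
    """Yield (x1, x2, x3, y) mini-batches as int64 tensors.

    Hypothetical helper: all four inputs are equal-length integer arrays.
    """
    n = len(y)
    order = np.random.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        rows = order[start:start + batch_size]
        # Index every array with the same rows so inputs and targets stay aligned.
        yield tuple(torch.from_numpy(a[rows]).long() for a in (x1, x2, x3, y))
```

Because it’s a generator, it has to be called again for each epoch.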
Then you have all the * weirdness that fast.ai uses together with the iterator. I’d never come across that syntax before, and I specifically remember Prof Jeremy saying in class, “If you don’t know what this means, look it up”. Turns out it’s Python’s iterable unpacking (the * is often called the unpack, or splat, operator), and it does this:
X = [1, 2, 3, 4]
*xs, y = X    # xs -> [1, 2, 3] and y -> 4
fn(*xs)       # same as calling fn(1, 2, 3)
Armed with this knowledge, we can train the model in the regular PyTorch way without needing any wrapper classes, and especially not the fit magic method that fast.ai’s notebook uses:
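The training code isn’t reproduced here. Below is a minimal sketch of a plain PyTorch training loop. It assumes the model takes three index tensors and returns log-probabilities, hence F.nll_loss. It also assumes make_batches is a hypothetical callable that returns a fresh iterable of (x1, x2, x3, y) batches for each epoch. The optimizer (Adam) and default learning rate are my choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train(model, make_batches, epochs=2, lr=1e-3):
    """Plain training loop; returns the mean loss of each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        model.train()
        total, count = 0.0, 0
        for *xs, y in make_batches():          # last item is the target, the rest are inputs
            optimizer.zero_grad()
            log_probs = model(*xs)             # unpack the three inputs into forward()
            loss = F.nll_loss(log_probs, y)    # model already returns log-probabilities
            loss.backward()
            optimizer.step()
            total += loss.item() * len(y)
            count += len(y)
        history.append(total / count)
    return history
```

The for *xs, y line is where the unpacking earns its keep: each batch’s last element is the target, and everything before it goes straight into the model.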
This actually trains faster than fast.ai’s model, while achieving similar accuracy, for what that’s worth (which is not much: remember this was their rubbishy, basic version). And while outcomes are not really the purpose of this article, it’d be disappointing to end without giving a sample output after training on the Bard for a couple of epochs:
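The sample output itself isn’t reproduced here. For anyone who wants to produce their own, below is a sketch of sampling text one character at a time. It assumes a three-input model that returns log-probabilities, plus the two character mappings built during pre-processing. The name generate and its parameters are hypothetical.

```python
import torch


def generate(model, seed, n_chars, char_to_idx, idx_to_char):
    """Extend seed by n_chars characters, sampling each from the model."""
    if len(seed) < 3:
        raise ValueError("need at least three seed characters")
    model.eval()
    out = list(seed)
    with torch.no_grad():
        for _ in range(n_chars):
            # The last three characters, each wrapped as a batch of one index.
            xs = [torch.tensor([char_to_idx[c]]) for c in out[-3:]]
            probs = model(*xs).exp()            # log-probabilities -> probabilities
            next_idx = torch.multinomial(probs[0], num_samples=1).item()
            out.append(idx_to_char[next_idx])
    return "".join(out)
```

Sampling with torch.multinomial, rather than always taking the most likely character, keeps the output from looping on the same few characters.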
Well, we retained the good stuff from fast.ai’s lesson 6 without needing to use their library. Of course, if you haven’t done the course yet, have a look. Google lets you run Jupyter notebooks with PyTorch on a GPU for free in Colaboratory, and you can also find my complete pared-down notebook on GitHub here.