When gradients just won't flow

Niloy Purkait
7 min read · Aug 13, 2020


Image source: https://www.fosslinux.com/34360/google-announces-200-open-source-mentors-for-the-2020-gsoc-event.htm

The adversarial training setup

Greetings all! Here comes another update on my GSoC project, concerning the development of an RDF-to-Text system using generative adversarial networks. The idea of the project is simple: we use a Transformer-based generator to produce natural language text from a given RDF triple input. These generations are shown to yet another transformer, this time acting as a discriminator. The discriminator learns to distinguish between generated text and real text for a corresponding RDF triple. Below are links to the previous articles documenting my progress in this project:

Previously on my GSoC Project:

In my previous article, I motivated the need for the discriminator to receive context, in the form of the corresponding RDF triple, while classifying a given text. After all, it would not make sense to arbitrarily show the discriminator two perfect sentences from the English language and expect it to tell which one better fits a given RDF triple instance. The discriminator needs to see the input triple in some way, and my solution was to concatenate the input triple sequence with the real and fake text, and then train the discriminator to distinguish between these two sequences. It all seemed well in theory, and initial experiments on pre-training the discriminator provided promising results. However, performing this concatenation during training seemed intuitively messy, so I decided to refine the approach slightly. Instead of concatenating input and target sequences, I decided to simply show them as separate inputs to the discriminator. Hence, our discriminator model now has 2 input layers, where each input layer is connected to its own embedding layer and transformer block, and the outputs of these blocks are concatenated and passed to the dense classification layer. This approach avoids performing any operations during training that may break the computation graph of TensorFlow, or so I thought, at least…
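For clarity, here is a minimal sketch of what such a two-input discriminator could look like with the Keras functional API. The vocabulary size, layer dimensions, and the simple transformer block below are illustrative assumptions rather than the project's actual code, and positional embeddings are omitted for brevity:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Illustrative hyperparameters, not the project's real values
vocab_size, max_len, embed_dim = 12000, 64, 256

def transformer_block(x, num_heads=4, ff_dim=512):
    # A standard Transformer encoder block: self-attention + feed-forward,
    # each followed by a residual connection and layer normalization
    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)
    x = layers.LayerNormalization(epsilon=1e-6)(x + attn)
    ff = layers.Dense(ff_dim, activation="relu")(x)
    ff = layers.Dense(embed_dim)(ff)
    return layers.LayerNormalization(epsilon=1e-6)(x + ff)

# Two separate inputs: the RDF triple sequence and the (real or generated) text
rdf_in = layers.Input(shape=(max_len,), name="rdf_triples")
txt_in = layers.Input(shape=(max_len,), name="text")

# Each input gets its own embedding layer and transformer block
rdf_feat = transformer_block(layers.Embedding(vocab_size, embed_dim)(rdf_in))
txt_feat = transformer_block(layers.Embedding(vocab_size, embed_dim)(txt_in))

# Concatenate the pooled representations and classify real vs. fake (raw logit output)
merged = layers.concatenate([layers.GlobalAveragePooling1D()(rdf_feat),
                             layers.GlobalAveragePooling1D()(txt_feat)])
output = layers.Dense(1)(merged)

discriminator = tf.keras.Model([rdf_in, txt_in], output, name="discriminator")
```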

While reading this article, you may at times wonder why specific parts of the code are explained in such detail. Well, it is essential that you are up to date on these finer details before I can elaborate on a very persistent and confusing problem that this setup runs into when we actually try to execute it. But all will be clear in due time. For now, let's start with a quick recap of some of the inner workings of TensorFlow.

Recap on TensorFlow essentials

While the Keras API is essentially a higher-level API that allows constructing and training neural nets with ease, some use cases (like ours) require finer control over the training step, which the lower-level API of TensorFlow provides. Whereas the Keras API normally keeps track of a model's gradients and applies them under the hood, TensorFlow allows hands-on manipulation of these 'under-the-hood' operations. This can be done through TensorFlow's tf.GradientTape API, which lets us automatically calculate the gradient of a computation with respect to some inputs, usually tf.Variables. TensorFlow records the relevant operations executed inside the context of a tf.GradientTape onto a "tape" object, and then uses that tape to compute the gradients of the "recorded" computation using reverse-mode differentiation. Thus, we can use tf.GradientTape to keep track of the computations being done in two different networks, and use the output of one of these networks (i.e. the discriminator) to update the weights of another network (i.e. the generator). For a simpler example of an adversarial training setup coded with tf.GradientTape, have a look at this image generation use case here.
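As a minimal, stand-alone illustration of these mechanics (unrelated to the project's actual models), here is how a computation is recorded and its gradient retrieved with tf.GradientTape:

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # Every operation on a watched tf.Variable inside this context is recorded
    y = x ** 2

# Reverse-mode differentiation over the recorded operations: dy/dx = 2x = 6.0
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 6.0
```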

Let's see the training algorithm!

Now that we are up to date on how TensorFlow computes gradients, we can dive deeper down the rabbit hole. Armed with a functioning generator and a discriminator, I coded in an adversarial training loop. Here’s a simplistic overview of the algorithm:

Requires: a functioning generator and discriminator

Simplified pseudo-code

For each instance in the dataset:

— Get predicted text from the generator, for the given RDF input

— Update the discriminator using real and generated RDF-text sequences

— Use the discriminator's predictions on the generated text to update the generator's weights

It seems quite neat when put like that; however, as we'll see, expressing the same idea in TensorFlow gets a tad more complex. Before we dive into the specifics of the training loop, we must open a small parenthesis on the loss functions used to train these networks.

What about the loss functions?

As one would expect, each network in this adversarial setup has its own loss function. The Discriminator's loss function quantifies its ability to distinguish real sequences from fakes. This is done by comparing the discriminator's predictions on real sequences to an array of 1s, and its predictions on fake (i.e. generated) sequences to an array of 0s. The losses for the real and the fake sequences are both computed via binary cross-entropy, and the results are summed to obtain the overall loss for the Discriminator.

Discriminator’s loss function
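Since the snippet above is shared as an image, here is a sketch of what such a discriminator loss typically looks like in TensorFlow, assuming the discriminator outputs raw logits (the same pattern as in the official DCGAN tutorial):

```python
import tensorflow as tf

# Binary cross-entropy on raw logits (the model applies no sigmoid itself)
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    # Real RDF-text pairs should be classified as 1s
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Generated RDF-text pairs should be classified as 0s
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    # The overall discriminator loss is the sum of both terms
    return real_loss + fake_loss
```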

Conversely, the Generator's loss quantifies its ability to trick the discriminator. In simpler terms, if the generator is doing well, the discriminator will classify generated sequences as real (or 1). Thus, we compute the generator's loss by comparing the discriminator's decisions on the generated sequences to an array of 1s.

Generator’s loss function
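And the corresponding sketch for the generator, reusing the cross_entropy object defined in the snippet above:

```python
# Reuses the BinaryCrossentropy(from_logits=True) object from the discriminator loss sketch
def generator_loss(fake_output):
    # The generator "wins" when the discriminator labels its outputs as real (1s)
    return cross_entropy(tf.ones_like(fake_output), fake_output)
```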

And that about covers it for the loss functions! If you would like to see the entire script for the adversarial setup, click here. Now back to the training step.

Adversarial Training. source: https://xkcd.com/303/

Deeper dive into the training loop

The loop is coded using two gradient tapes, one for the generator and the other for the discriminator. We start by having the generator produce predictions for a batch of RDF triples, and feed the discriminator these predictions as well as the real target sequences. As you may recall, the discriminator sees each training instance as 2 inputs — one being the input triple sequence and the other being the corresponding text, be it generated or real. From these RDF and text input sequences, it learns to tell whether a given instance is fake or real. Thus, once we have the generator's predictions for a given batch, we show them to the discriminator along with the input RDF batch. We also show it the real texts with their corresponding RDF triples. Then, we use the discriminator's predictions to calculate two separate losses (on real and fake instances) and update the weights of the discriminator with the combined loss on both batches. Next, to update the generator, we take the discriminator's predictions on the fake instances and compare them to the ideal case where each of them would be predicted as real (an array of 1s), since this would mean our generator has fooled the discriminator into thinking that each generated output is a real sequence. More specifically, we calculate the loss using the generator's loss function, by providing it the discriminator's predictions on the generated batch and taking the cross-entropy between an array of 1s and those predictions. And here's where it all breaks down!
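To make this concrete, here is a condensed sketch of such a training step, closely following the DCGAN tutorial pattern referred to earlier. The names (generator, discriminator, the two optimizers, and the loss functions from the previous section) are illustrative, and the actual notebook differs in its details, notably in how the generator's token sequences are produced and fed to the discriminator:

```python
@tf.function
def train_step(rdf_batch, real_text_batch):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 1. The generator produces text for the batch of RDF triples
        generated_text = generator(rdf_batch, training=True)

        # 2. The discriminator scores real and generated RDF-text pairs
        real_output = discriminator([rdf_batch, real_text_batch], training=True)
        fake_output = discriminator([rdf_batch, generated_text], training=True)

        # 3. Both losses are computed from the discriminator's predictions
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 4. Differentiate each loss w.r.t. its own network's weights and apply the updates
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
```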

Want the whole notebook? Click here!

While the networks train perfectly well during pre-training, when we initiate the adversarial training step described above, we run into a pretty big problem. It looks something like this:

No gradients provided error while training the generator

What’s going on here?

From initial probings into the issue, it seemed that the generator's gradients were not being computed, and hence the optimizer receives an array of 'None' values from the gen_tape.gradient() method. Interestingly, the discriminator's gradient tape does not run into the same problem. Another observation is that the loss tensor for the generator is indeed being computed (as displayed in the output above, right before the error message breaks the execution), but gen_tape simply does not differentiate this loss tensor with respect to the model's weights.

I tried several things before turning to the online community for help. For instance, I made the gradient tape track the tensors needed to compute the loss manually, by first disabling automatic tracking (watch_accessed_variables=False) within the generator's gradient tape, and then explicitly tracking each tensor needed for the computation using the gen_tape.watch() method. I also tried coding the gradient tapes differently, in nested and sequential implementations. Other solutions I tried involved sub-classing my model with the Keras API and overriding the model.fit() training logic with my custom training step, in order to use the Keras API to train the networks — still the same error.
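For reference, manually controlling what the tape tracks looks roughly like this (the variable and tensor names are again illustrative, matching the sketch above):

```python
with tf.GradientTape(watch_accessed_variables=False) as gen_tape:
    # Nothing is tracked automatically; only explicitly watched tensors/variables are recorded
    gen_tape.watch(generator.trainable_variables)
    generated_text = generator(rdf_batch, training=True)
    fake_output = discriminator([rdf_batch, generated_text], training=True)
    gen_loss = generator_loss(fake_output)

# Still returns a list of None values if the loss is not connected to the
# generator's weights in the recorded computation graph
gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
```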

Current status:

An exhaustive online search for similar issues was somewhat informative, yet unfruitful. The bottom line was always the same: 'computations are not being traced because x' or 'the gradient tape loses track of variables because of y'. However, it is not clear which exact computation in my training step breaks the GradientTape, as the step is very closely inspired by a TensorFlow tutorial on a DCGAN architecture, where the training loop and loss functions are implemented in the same way. The only difference in our project concerns the models and data used.

For now, the only choice I'm left with is to start from scratch and re-implement the whole project in PyTorch, hoping that similar issues are not encountered along the way. Meanwhile, I am still trying to solve the error by consulting experts and the online community at large, but nothing promising has come out of it as of yet. It goes without saying that any insight you may have on this issue would be deeply appreciated by the author of this article. For now, I hope you have found this progress update interesting, and I will be sure to update you once again if I become aware of a solution to this issue.

Written by: Niloy Purkait


Data Scientist | Strategy consultant | Machine and Deep learning enthusiast. Interests range from computational biology and theoretical physics to big data tech