
Multi-Digit Sequence Recognition With CRNN and CTC Loss Using PyTorch Framework

Theory

Optical Character Recognition (OCR) is quite an old problem, dating back to the 1970s when the first omni-font OCR technology was developed. The complexity of this task comes from many natural features of text:

  • In some alphabets (Arabic, for example, especially in the cursive form) letters are much harder to locate and recognize.
  • There are many different fonts and styles, some of which make characters look too similar (like the letters I and l, or the number 0 and the letter O).
  • Handwritten text comes in all shapes and sizes and even the most advanced tools like Google Cloud Vision might not recognize all the letters correctly.

All deep learning OCR methods can be roughly divided into three broad categories:

  • Character-based (Bissacco et al., Jaderberg et al.): these methods first try to find the locations of individual characters, recognize them, and then group them into words.
  • Word-based (also by Jaderberg et al.): these methods treat text recognition as a word classification problem, where the classes are common words in a specific language.
  • Sequence-to-sequence: these methods treat OCR as a sequence labeling problem. Some of the earliest works of this type were written by He et al., Shi et al., and Su et al. This paper by Shi is the original work on the CRNN model used in this article and gives the most thorough and intuitive description of this architecture. Also, this paper provides a more elaborate overview of the specific GRU-CNN architecture from a computational standpoint. Various modifications of the CRNN model outperform other approaches on many reference OCR datasets.

CRNN architecture

In essence, the CRNN model is a combination of convolutional neural network (CNN) and recurrent neural network (RNN).

CNNs are a special family of neural networks that over the last 20 years have provided SotA performance in various computer vision tasks. There are countless articles and papers that explain in detail how CNNs work, but for those of you who missed them, a quick recap might be helpful, so I’ll provide it below.

RNNs are capable of capturing temporal and spatial features within a sequence of inputs in their hidden state, which makes them suitable for tasks like handwriting and speech recognition.

So, here are the key principles of CRNN architecture:

  • Let’s start with the CNN part. Unlike dense neural networks, convolutional layers are able to capture the relationships between neighboring pixels using a set of so-called filters or kernels: weight matrices that slide over the image with a certain stride and compute the convolution operation at each step.
  • Sometimes a batch normalization or instance normalization layer is used after a convolutional layer. The main purpose of this kind of layer is to standardize the input features so that each has zero mean and unit variance. Both types of normalization are quite similar and differ only in which values are normalized together: all outputs across the batch within each channel (batch normalization) or only the outputs within each channel of a single image (instance normalization). There is a great article on the various types of normalization. Intuitively, normalization makes the surface of the loss function smoother, which helps training converge faster and more stably, but, as far as I know, this has not been rigorously proven.
  • In CRNN, the fully-connected layers at the end of the CNN are not used; instead, the output from the convolutional layers (the feature maps) is transformed into a sequence of feature vectors. These vectors are then fed into some type of bidirectional RNN (a GRU in our case). The Gated Recurrent Unit (GRU) is a slightly simpler version of the LSTM architecture that has comparable performance but requires much less computational power. This part of the model produces a probability distribution over the labels for each feature vector. For example, in the model from the practice section below, the output of the last instance normalization layer has the shape (batch_size, 64, 4, 32), where the dimensions are (batch_size, channels, height, width). We then permute the dimensions to (batch_size, width, height, channels) so that channels comes last. According to the original paper, “Specifically, each feature vector of a feature sequence is generated from left to right on the feature maps by column. This means the i-th feature vector is the concatenation of the i-th columns of all the maps.” After that, the resulting tensor is reshaped to (batch_size, 32, 256) and fed into the GRU layers, which produce a tensor of shape (batch_size, 32, 256). This tensor is then passed through a fully-connected layer and the log_softmax function to return a tensor of shape (batch_size, 32, 11), which, for each image in the batch, contains the log-probabilities of each label at each time step. It is used to compute the CTC loss, which is explained in the next section. A minimal sketch of this reshaping step is shown right after this list.
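Here is that reshaping step as a minimal, self-contained sketch (the batch size and the GRU hidden size are my own assumptions; only the tensor shapes follow the description above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical CNN output with the shape described above: (N, C, H, W)
feature_maps = torch.randn(8, 64, 4, 32)
# (N, C, H, W) -> (N, W, H, C): every column of the feature maps becomes
# one time step, and (H, C) is flattened into a 4 * 64 = 256-dim vector
sequence = feature_maps.permute(0, 3, 2, 1).reshape(8, 32, 256)
# a bidirectional GRU with hidden size 128 outputs 256 features per step
gru = nn.GRU(input_size=256, hidden_size=128, num_layers=2,
             batch_first=True, bidirectional=True)
fc = nn.Linear(2 * 128, 11)  # 10 digit classes + 1 blank label
out, _ = gru(sequence)                     # (8, 32, 256)
log_probs = F.log_softmax(fc(out), dim=2)  # (8, 32, 11)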

CTC loss function

Unfortunately, when we are dealing with the output produced by the CRNN described above, we cannot use the regular cross-entropy loss, because many different sequences of per-time-step predictions can correspond to the same target sequence.

For example, for the target sequence [5, 3, 8, 3, 0] and a blank value of 10, the raw predictions (after taking the indices with the maximum probability) can look as follows:

[10, 5, 10, 10,  3, 10, 10,  8, 10, 10,  3, 10, 10, 10, 0, 0, 10]
[10, 5,  5, 10,  3,  3,  3, 10, 10,  8, 10, 10,  3, 10, 0, 0, 10]

After collapsing the repeated values and removing all the blanks, both outputs decode to the same target sequence, and the loss function has to account for every such valid alignment. The Connectionist Temporal Classification loss, or CTC loss, was designed for exactly this kind of problem.
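To make that decoding rule concrete, here is a minimal sketch of greedy decoding that first collapses repeats and then drops blanks (the helper name is mine, not from the original code):

import itertools

BLANK = 10

def greedy_decode(raw_prediction):
    # collapse consecutive repeated labels, then remove the blanks
    collapsed = [label for label, _ in itertools.groupby(raw_prediction)]
    return [label for label in collapsed if label != BLANK]

# both raw predictions from the example above decode to the same target
print(greedy_decode([10, 5, 10, 10, 3, 10, 10, 8, 10, 10, 3, 10, 10, 10, 0, 0, 10]))
print(greedy_decode([10, 5, 5, 10, 3, 3, 3, 10, 10, 8, 10, 10, 3, 10, 0, 0, 10]))
# [5, 3, 8, 3, 0]
# [5, 3, 8, 3, 0]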

Essentially, CTC loss is computed using the ideas of the HMM forward algorithm and dynamic programming. To visualize the main idea, it might be helpful to construct a table where the X axis represents time steps and the Y axis represents the output sequence. Blank elements are added at the beginning and the end of the sequence, as well as between all of its elements, so the resulting length is 2|l| + 1. For the target [5, 3, 8, 3, 0], for example, the extended sequence becomes [b, 5, b, 3, b, 8, b, 3, b, 0, b] of length 11.

Let’s assume that the transitions between nodes have the following constraints:

  • Transitions can only go to the right or to the lower right.
  • There must be at least one blank element between two identical non-blank elements.
  • Non-blank elements cannot be skipped.
  • The path must start at one of the first two elements and end at one of the last two elements.
Examples of legal (blue lines) and illegal (red lines) transitions between nodes. Legal starting and ending points are denoted by green rectangles.
Visualization of all valid paths for the word “apple”

The total probability of the target sequence is the sum of the probabilities of all these paths, and the loss is its negative logarithm. We can calculate it quite efficiently using dynamic programming, as shown below.

Forward algorithm

This implementation of the forward algorithm is part of the PyTorch unit tests. I’ll try to explain how it adheres to the original paper by Graves et al. As mentioned at the beginning of this function, “this directly follows Graves et al’s paper, in contrast to the production implementation, it does not use log-space”. Let’s follow this code step by step.

The key idea of this algorithm is expressed in section 4.1 of the original paper.
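In my notation, with l' denoting the target sequence with blanks inserted and y^t_k the probability of label k at time step t, the recurrence from equations 6 and 7 of the paper can be written as:

\alpha_t(s) = \bigl(\alpha_{t-1}(s) + \alpha_{t-1}(s-1)\bigr)\, y^t_{l'_s} \quad \text{if } l'_s = b \text{ or } l'_s = l'_{s-2}

\alpha_t(s) = \bigl(\alpha_{t-1}(s) + \alpha_{t-1}(s-1) + \alpha_{t-1}(s-2)\bigr)\, y^t_{l'_s} \quad \text{otherwise}

with the initialization \alpha_1(1) = y^1_b, \alpha_1(2) = y^1_{l_1}, \alpha_1(s) = 0 for s > 2, and the total probability p(l \mid x) = \alpha_T(|l'|) + \alpha_T(|l'| - 1).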

Now we need to prepare the extended target sequence according to the format described in the paper:

# fill target sequences with blank symbol
targets_prime = targets.new_full((2 * target_length + 1,), blank)
# then fill every odd value with target symbol
if targets.dim() == 2:
    targets_prime[1::2] = targets[i, :target_length]
else:
    targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]

We also have to convert the original inputs from log-space back to regular probabilities:

probs = log_probs[:input_length, i].exp()

The next step is to initialize variables for this dynamic programming algorithm:

# alpha has the same length as the extended target sequence
alpha = log_probs.new_zeros((target_length * 2 + 1,))
alpha[0] = probs[0, blank]
alpha[1] = probs[0, targets_prime[1]]

Now we can compute the forward loss according to the formulas 6 and 7 from the original paper:

# this mask is True only when targets_prime[s] != targets_prime[s - 2];
# note that every element at an even index is a blank,
# so this condition never holds for blanks
mask_third = (targets_prime[:-2] != targets_prime[2:])
for t in range(1, input_length):
    alpha_next = alpha.clone()
    # we always add alpha[s - 1] to alpha[s]
    alpha_next[1:] += alpha[:-1]
    # but we add alpha[s - 2] to alpha[s] only
    # when the mask condition is True
    alpha_next[2:] += torch.where(mask_third, alpha[:-2],
                                  alpha.new_zeros(1))
    alpha = probs[t, targets_prime] * alpha_next

Since the path probabilities were accumulated in normal (non-log) space, we take the negative log of the total probability to get the loss for each input sequence, and then reduce the result either by summing all the losses or by calculating their mean:

# to evaluate the maximum likelihood error, we need the natural logs
# of the target labelling probabilities
losses.append(-alpha[-2:].sum().log()[None])
output = torch.cat(losses, 0)
if reduction == 'mean':
    return (output / target_lengths.to(dtype=output.dtype,
                                       device=output.device)).mean()
elif reduction == 'sum':
    return output.sum()
output = output.to(dt)
return output

The full reference implementation is available here. Now we can test it against the built-in CTC loss implementation. As you can see, the results are identical:

ctc_loss = torch.nn.CTCLoss()
# target lengths are specified for each sequence in this case, 75 total
target_lengths = [30, 25, 20]
# input lengths are specified for each sequence to achieve masking
# under the assumption that sequences are padded to equal lengths
input_lengths = [50, 50, 50]
# target sequences are represented as a tensor of size
# (sum(target_lengths)). Each element in the target sequence is a
# class index. In the (sum(target_lengths)) form, the targets are
# assumed to be un-padded and concatenated within 1 dimension.
targets = torch.randint(1, 15, (sum(target_lengths),),
                        dtype=torch.int)
# the logarithmized probabilities of the outputs -
# tensor of size (T, N, C), where
# T = input length,
# N = batch size,
# C = number of classes
log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)

res = ctc_loss(log_probs, targets, input_lengths, target_lengths)
expected = ctcloss_reference(log_probs, targets,
                             input_lengths, target_lengths).float()
print(res)
print(expected)
# output:
# tensor(4.0334)
# tensor(4.0334)

Here are some very useful articles and lectures that helped me to understand various aspects of the CTC loss:

  • I think it’s a good idea to start exploring CTC loss with this article. It provides a very intuitive matrix representation of the RNN output and of how it’s processed by the CTC operation, along with a bird’s-eye view of the problem itself.
  • This guide is a little more complicated and formal than the previous one, but it gives a very detailed explanation of the problem, of how alignments are calculated, and of how CTC relates to other commonly used algorithms for sequence modeling.
  • This lecture by Carnegie Mellon University is the ultimate guide to CTC loss; I think it’s by far the best resource on this subject.
  • And last but not least, the original paper by Graves et al. provides the base algorithm for CTC loss calculation and all the formal mathematics behind it.

Practice

The main goal of this practice section is to concatenate multiple images from the EMNIST dataset and learn how to recognize the resulting sequence of digits using the CRNN architecture and the CTC loss function described above. This example is written using PyTorch 1.7.0 and the EMNIST dataset from the torchvision package. The project is also available on my GitHub.

Now we need to prepare our dataset. A single training sample is a tensor obtained by combining five normalized and transformed EMNIST images:

Transformation and normalization of EMNIST images into the form required by the CRNN model
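As a rough illustration of this step, here is a small sketch that concatenates five random EMNIST digits into one wide image (the normalization statistics, the transpose of the stored images, and the helper name are my assumptions, not the exact transforms from the project):

import torch
from torchvision import datasets, transforms

# MNIST-style normalization stats; the exact values are an assumption here
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

emnist = datasets.EMNIST('data', split='digits', train=True,
                         download=True, transform=transform)

def make_sequence_sample(dataset, seq_len=5):
    # pick seq_len random digits and concatenate them along the width
    idx = torch.randint(len(dataset), (seq_len,)).tolist()
    images, labels = zip(*(dataset[i] for i in idx))
    # EMNIST stores images transposed, so swap height and width first
    images = [img.transpose(1, 2) for img in images]
    return torch.cat(images, dim=2), torch.tensor(labels)

image, target = make_sequence_sample(emnist)
print(image.shape, target)  # torch.Size([1, 28, 140]) and 5 digit labels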

This GRU-CNN model follows directly from the theoretical section:

GRU-CNN (CRNN) model
Training loop and validation
Model sanity test
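For reference, the way nn.CTCLoss is typically wired into such a training step might look like the following sketch; the batch size is an assumption, and the random log_probs tensor merely stands in for the real CRNN output of shape (batch_size, 32, 11):

import torch
import torch.nn as nn

batch_size, time_steps, num_classes, target_len = 16, 32, 11, 5
criterion = nn.CTCLoss(blank=10, reduction='mean')

# stand-in for the CRNN output; in the real loop this comes from the model
log_probs = torch.randn(batch_size, time_steps, num_classes,
                        requires_grad=True).log_softmax(2)
targets = torch.randint(0, 10, (batch_size, target_len), dtype=torch.long)

# nn.CTCLoss expects (T, N, C), so the batch-first output must be permuted
input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)
target_lengths = torch.full((batch_size,), target_len, dtype=torch.long)
loss = criterion(log_probs.permute(1, 0, 2), targets,
                 input_lengths, target_lengths)
loss.backward()  # in the real loop, followed by optimizer.step()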

Results

As you can see, the model performed well even though, due to computational limitations, it is smaller than the original one:

Training and validation accuracy
Trained model test results
