Implementing RoI Pooling in TensorFlow + Keras

by Jaime Sevilla, published in xplore.ai, Apr 1, 2019

In this post we explain the basic concept and general usage of RoI (Region of Interest) pooling and provide an implementation using Keras layers and the TensorFlow backend.

The intended audience for this post is people familiar with the basic theory of (convolutional) neural networks who are capable of building and running simple models using Keras.

If you are here just for the code, help yourself to this gist, and do not forget to like and share the article!

Understanding RoI Pooling

RoI Pooling was proposed by Ross Girshick in the Fast R-CNN paper as part of his object detection pipeline.

In the general use case for RoI Pooling we have an image-like object, and multiple regions of interest specified via bounding boxes. We want to generate an embedding out of each RoI.

For example, in the R-CNN setup we have an image and a proposal mechanism that produces bounding boxes for potentially interesting parts of the image. We want to embed each suggested patch of the image.

Simply cropping each suggested region will not work, because we want to stack the resulting embeddings on top of each other, and the suggested regions do not necessarily have the same shape!

So we need a way of transforming each patch into an embedding of a predefined shape. How can we do that?

A standard way of reducing the spatial size of images in computer vision is some sort of pooling operation.

The most common one is max pooling, where we divide the input image into (usually non-overlapping) areas of equal shape and form the output by taking the maximum value found in each area.

The maxpool operation divides each region into pooling areas of the same size

This does not directly solve the problem we have — patches of different shapes will be divided into a variable number of areas of equal shape, producing an output of variable shape.
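To make the problem concrete, here is a minimal NumPy sketch (not part of the layer we build later) of ordinary non-overlapping max pooling with a fixed 2×2 window. The helper name `max_pool_fixed_window` is hypothetical. Note how the output shape follows the input shape.

```python
import numpy as np


def max_pool_fixed_window(image, window=2):
    """Non-overlapping max pooling with a fixed square window.

    Trailing rows/columns that do not fill a whole window are dropped,
    as with 'valid' padding.
    """
    out_h = image.shape[0] // window
    out_w = image.shape[1] // window
    # Trim to a multiple of the window, then expose each window as its own axes.
    trimmed = image[:out_h * window, :out_w * window]
    blocks = trimmed.reshape(out_h, window, out_w, window)
    return blocks.max(axis=(1, 3))


small_patch = np.arange(4 * 6, dtype=float).reshape(4, 6)
large_patch = np.arange(8 * 10, dtype=float).reshape(8, 10)

print(max_pool_fixed_window(small_patch).shape)  # (2, 3)
print(max_pool_fixed_window(large_patch).shape)  # (4, 5)
```

Two patches of different sizes produce embeddings of different sizes, so they cannot be stacked into one tensor.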

But this suggests an idea: what if we divide each region of interest into the same number of areas, each of variable shape, and take the maximum of each of them?

The ROI Pooling operation divides all regions into equivalent grids of pooling areas

And that is exactly what the ROI Pooling layer does.
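A minimal NumPy sketch of the idea, assuming a single-channel region given as a 2D array. It uses `np.array_split` to cut the region into a fixed `pooled_h × pooled_w` grid. Note that `np.array_split` gives the extra rows and columns to the first cells, which differs from the layer implemented below (there the leftover pixels go to the last row and column of the grid). The function name `grid_max_pool` is hypothetical.

```python
import numpy as np


def grid_max_pool(region, pooled_h, pooled_w):
    """Max-pool a 2D region into a fixed pooled_h x pooled_w grid.

    Assumes the region has at least pooled_h rows and pooled_w columns,
    so that no grid cell is empty.
    """
    out = np.empty((pooled_h, pooled_w), dtype=region.dtype)
    # Split into pooled_h horizontal bands, then each band into pooled_w cells.
    for i, band in enumerate(np.array_split(region, pooled_h, axis=0)):
        for j, cell in enumerate(np.array_split(band, pooled_w, axis=1)):
            out[i, j] = cell.max()
    return out


# Two regions of interest with different shapes...
region_a = np.arange(13 * 9, dtype=float).reshape(13, 9)
region_b = np.arange(40 * 21, dtype=float).reshape(40, 21)

# ...are pooled into embeddings of the same shape.
print(grid_max_pool(region_a, 3, 7).shape)  # (3, 7)
print(grid_max_pool(region_b, 3, 7).shape)  # (3, 7)
```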

Use cases for RoI pooling

RoI pooling is a fairly general tool. It usually shows up in conjunction with RoI proposal mechanisms, bridging the gap between region proposals and their embedding. We will see two concrete examples to illustrate its potential.

Firstly, in the context in which it was developed (object detection), it allows us to split the task pipeline into two parts (region proposal and classification of regions) while retaining an end-to-end, single-pass, differentiable architecture.

Fast R-CNN architecture showcasing RoI pooling, by Ross Girshick

Thus in the Fast R-CNN model a proposal mechanism first supplies a set of RoIs. RoI pooling turns the features of each proposed RoI into a fixed-size map, which lets the same fully connected classifier score every proposal in a single forward pass over the image.

Secondly, together with region proposals, RoI Pooling can also be used to implement visual attention.

Attentional Network for Visual Object Detection showcasing RoI pooling, by Hara et al.

As an example, in Attentional Network for Visual Object Detection, Hara et al. implement attention using iterative RoI proposals and RoI pooling. They first generate an initial proposal (t=1), which RoI pooling adapts to the input of the fully connected layers that follow. The result is used as input to the Glimpse component (t=2), which generates a new region proposal that is again embedded using RoI pooling. The process is repeated T times.

Typing it

Before we dive into the implementation let’s pause for a minute to think about the type signature of the ROI layer.

It takes in two tensors:

  • A batch of images. In order to be able to process them together, all images must have the same shape. The resulting shape of the tensor will be (batch_size, img_height, img_width, n_channels).
  • A batch of Region Of Interest (ROI) proposals. If we want to stack them together in a tensor, the number of proposed regions must be fixed for each image. Since each bounding box must be specified with 4 coordinates, the shape of this tensor will be (batch_size, n_rois, 4).

And it must output:

  • For each image, one embedding per ROI, encoding the region that ROI specifies. The corresponding shape should be (batch_size, n_rois, pooled_height, pooled_width, n_channels).

The code in Keras

Keras allows us to easily implement custom layers by subclassing the base Layer class.

The tf.keras documentation recommends implementing the __init__, build and call methods for our custom layer. However, since the purpose of the build function is to add the weights of the layer and our layer has no weights, we do not need to override that method. We will also implement the convenience method compute_output_shape.

We will code each part separately and then put all the pieces together in the end.

```python
def __init__(self, pooled_height, pooled_width, **kwargs):
    self.pooled_height = pooled_height
    self.pooled_width = pooled_width
    super(ROIPoolingLayer, self).__init__(**kwargs)
```

The constructor of the class is quite simple to understand. We need to specify the target height and width of the embeddings we are producing. In the last line of the constructor we call the parent constructor to initialize the rest of the class attributes.

```python
def compute_output_shape(self, input_shape):
    """ Returns the shape of the ROI Layer output
    """
    feature_map_shape, rois_shape = input_shape
    assert feature_map_shape[0] == rois_shape[0]
    batch_size = feature_map_shape[0]
    n_rois = rois_shape[1]
    n_channels = feature_map_shape[3]
    return (batch_size, n_rois, self.pooled_height,
            self.pooled_width, n_channels)
```

compute_output_shape is a small utility that tells us what the shape of the layer’s output will be for a given input shape.

Next we have to implement call. The call function is where the logic of the layer lives. This function should take as input the two tensors that hold the input to the ROI Pooling Layer, and should output the tensor with the embeddings.

Before we implement that, we need to implement a simpler function that will take a single image and a single ROI and return the corresponding embedding.

Let’s do that step by step.

```python
@staticmethod
def _pool_roi(feature_map, roi, pooled_height, pooled_width):
    """ Applies ROI Pooling to a single image and a single ROI
    """
    # Compute the region of interest
    feature_map_height = int(feature_map.shape[0])
    feature_map_width = int(feature_map.shape[1])

    h_start = tf.cast(feature_map_height * roi[0], 'int32')
    w_start = tf.cast(feature_map_width * roi[1], 'int32')
    h_end = tf.cast(feature_map_height * roi[2], 'int32')
    w_end = tf.cast(feature_map_width * roi[3], 'int32')

    region = feature_map[h_start:h_end, w_start:w_end, :]
    ...
```

The first six lines of the function are computing where the region of interest starts and finishes within the image.

We have chosen as a convention that the coordinates of each ROI are specified in relative terms, as numbers between 0 and 1. Concretely, each ROI is a 4-dimensional tensor containing four relative coordinates (y_min, x_min, y_max, x_max), where y runs along the height and x along the width of the feature map, matching the indexing in the code above.

We could have specified ROIs in absolute pixel coordinates instead, but this is generally less convenient. A common pattern is to pass the input image through some convolutions that change its shape before feeding it to the ROI Pooling layer, and absolute coordinates would force us to keep track of those shape changes in order to rescale the bounding boxes properly.
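Here is a small pure-Python sketch of the relative-coordinate convention, assuming (as the code above does) that `roi[0]` and `roi[2]` scale with the height and `roi[1]` and `roi[3]` with the width. The helper `roi_to_pixels` is hypothetical; it truncates with `int()`, like the `tf.cast(..., 'int32')` calls do for non-negative values. The same relative ROI maps onto roughly the same area of a feature map downsampled by a factor of 4, with no extra bookkeeping.

```python
def roi_to_pixels(roi, feature_map_height, feature_map_width):
    """Convert a relative ROI (y_min, x_min, y_max, x_max) in [0, 1]
    into integer slice bounds (h_start, w_start, h_end, w_end).

    int() truncates toward zero, which matches tf.cast(..., 'int32')
    for the non-negative values used here.
    """
    y_min, x_min, y_max, x_max = roi
    return (int(feature_map_height * y_min),
            int(feature_map_width * x_min),
            int(feature_map_height * y_max),
            int(feature_map_width * x_max))


roi = (0.5, 0.2, 0.75, 0.4)

# A 200x100 input, and a 50x25 feature map after 4x downsampling.
print(roi_to_pixels(roi, 200, 100))  # (100, 20, 150, 40)
print(roi_to_pixels(roi, 50, 25))    # (25, 5, 37, 10)
```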

The seventh line just crops the image down to the region of interest using the super powerful tensor slicing syntax TensorFlow provides us.

```python
    ...
    # Divide the region into non overlapping areas
    region_height = h_end - h_start
    region_width = w_end - w_start
    h_step = tf.cast(region_height / pooled_height, 'int32')
    w_step = tf.cast(region_width / pooled_width, 'int32')

    areas = [[(
                i*h_step,
                j*w_step,
                (i+1)*h_step if i+1 < pooled_height else region_height,
                (j+1)*w_step if j+1 < pooled_width else region_width
              )
              for j in range(pooled_width)]
             for i in range(pooled_height)]
    ...
```

In the next four lines we compute the shape of each area within the ROI that is going to be pooled.

Afterwards we build a 2D list in which each element is a tuple holding the start and end coordinates of one of the areas over which we will take the maximum.

The code that generates the grid of area coordinates may seem overly complicated, but notice that if we simply divided the region of interest into areas of shape (region_height // pooled_height, region_width // pooled_width), some pixels of the ROI would not fall within any area.

We fix that by extending the rightmost and bottommost areas to cover the remaining pixels that would otherwise not fall in any area. In the code this is done with conditional expressions for the end coordinates of each area.
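A pure-Python sketch of this grid arithmetic along one axis, using a 40×20 region pooled into a 3×7 grid (the size of the first ROI in the test at the end of this post). `naive_bounds` and `extended_bounds` are hypothetical helper names; the second reproduces the layer’s rule of stretching the last area to the edge of the region.

```python
def naive_bounds(length, n_bins):
    """Equal bins of size length // n_bins; trailing pixels are left out."""
    step = length // n_bins
    return [(k * step, (k + 1) * step) for k in range(n_bins)]


def extended_bounds(length, n_bins):
    """Same bins, but the last one is stretched to cover the remainder,
    mirroring the conditional expressions in the layer's `areas` list."""
    step = length // n_bins
    return [(k * step, (k + 1) * step if k + 1 < n_bins else length)
            for k in range(n_bins)]


region_height, region_width = 40, 20
pooled_height, pooled_width = 3, 7

print(naive_bounds(region_height, pooled_height))
# [(0, 13), (13, 26), (26, 39)]  -> row 39 is never pooled
print(extended_bounds(region_height, pooled_height))
# [(0, 13), (13, 26), (26, 40)]
print(extended_bounds(region_width, pooled_width)[-1])
# (12, 20)  -> the last column area is 8 pixels wide instead of 2
```

One consequence worth keeping in mind: if a region has fewer rows than pooled_height (or fewer columns than pooled_width), the step becomes 0 and every area except the last one is empty, so very small ROIs need to be avoided or handled separately.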

The result is a 2D list of bounding boxes, and we move on to the next part.

```python
    ...
    # Take the maximum of each area and stack the result
    def pool_area(x):
        return tf.math.reduce_max(region[x[0]:x[2], x[1]:x[3], :], axis=[0, 1])

    pooled_features = tf.stack([[pool_area(x) for x in row] for row in areas])
    return pooled_features
```

These lines do all the magic. We define an auxiliary function pool_area that takes as input a bounding box specified by a tuple like the ones we just created, and outputs the maximum of each channel within that area.

We map pool_area over every area we have declared using a list comprehension.

At this point we return a tensor of shape (pooled_height, pooled_width, n_channels) holding the result of pooling one region of interest of a single image.
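To check the layer’s arithmetic without a TensorFlow session, here is a NumPy sketch that mirrors the same steps for one feature map and one ROI. `pool_roi_np` is a hypothetical name; it assumes a `(height, width, n_channels)` array and a relative ROI `(y_min, x_min, y_max, x_max)`, like the layer. It multiplies in float32 before truncating, as the layer does with a float32 ROI tensor.

```python
import numpy as np


def pool_roi_np(feature_map, roi, pooled_height, pooled_width):
    """NumPy mirror of the _pool_roi logic: one feature map, one ROI."""
    roi = np.asarray(roi, dtype=np.float32)
    # Relative -> absolute bounds, computed in float32 and then truncated.
    fm_height = np.float32(feature_map.shape[0])
    fm_width = np.float32(feature_map.shape[1])
    h_start = int(fm_height * roi[0])
    w_start = int(fm_width * roi[1])
    h_end = int(fm_height * roi[2])
    w_end = int(fm_width * roi[3])
    region = feature_map[h_start:h_end, w_start:w_end, :]

    region_height, region_width = region.shape[0], region.shape[1]
    h_step = region_height // pooled_height
    w_step = region_width // pooled_width

    out = np.empty((pooled_height, pooled_width, feature_map.shape[2]),
                   dtype=feature_map.dtype)
    for i in range(pooled_height):
        for j in range(pooled_width):
            # The last row/column of areas absorbs any leftover pixels.
            h0, w0 = i * h_step, j * w_step
            h1 = (i + 1) * h_step if i + 1 < pooled_height else region_height
            w1 = (j + 1) * w_step if j + 1 < pooled_width else region_width
            # Channel-wise maximum over the area.
            out[i, j, :] = region[h0:h1, w0:w1, :].max(axis=(0, 1))
    return out


# Same input as the test later in the post: 200x100, all ones, one 50.
feature_map = np.ones((200, 100, 1), dtype=np.float32)
feature_map[199, 97, 0] = 50
rois = np.asarray([[0.5, 0.2, 0.7, 0.4], [0.0, 0.0, 1.0, 1.0]],
                  dtype=np.float32)

first = pool_roi_np(feature_map, rois[0], 3, 7)
second = pool_roi_np(feature_map, rois[1], 3, 7)
print(first[:, :, 0])   # all ones
print(second[:, :, 0])  # ones, with 50. in the bottom-right cell
```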

The next step is pooling many ROIs from a single image. This is straightforward to implement using an auxiliary function and tf.map_fn, producing a tensor of shape (n_rois, pooled_height, pooled_width, n_channels).

```python
@staticmethod
def _pool_rois(feature_map, rois, pooled_height, pooled_width):
    """ Applies ROI pooling for a single image and various ROIs
    """
    def curried_pool_roi(roi):
        return ROIPoolingLayer._pool_roi(feature_map, roi,
                                         pooled_height, pooled_width)

    pooled_areas = tf.map_fn(curried_pool_roi, rois, dtype=tf.float32)
    return pooled_areas
```

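Lastly, we need to implement the iteration over the batch. If we pass tf.map_fn a sequence of tensors (such as our input x, which holds the feature maps and the ROIs), it unstacks each of them along the first dimension and passes the matching slices to our function together, taking care of the zipping for us.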

```python
def call(self, x):
    """ Maps the input tensor of the ROI layer to its output
    """
    def curried_pool_rois(x):
        return ROIPoolingLayer._pool_rois(x[0], x[1],
                                          self.pooled_height,
                                          self.pooled_width)

    pooled_areas = tf.map_fn(curried_pool_rois, x, dtype=tf.float32)
    return pooled_areas
```

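Notice that we must specify the dtype parameter of tf.map_fn whenever the output of the mapped function differs from its input in type or structure: in call, for example, each element is a pair of tensors, but the function returns a single tensor. (Newer TensorFlow releases deprecate dtype in favor of an equivalent parameter named fn_output_signature.) In general it is good practice to specify it as often as possible, to make explicit how types change through our TensorFlow computation graph.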
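The zipping behavior of tf.map_fn over a pair of tensors can be pictured with plain NumPy: slice every input along the batch axis, call the function on each matching pair, and stack the results. This is only an illustration of the behavior, not TensorFlow’s implementation; `map_fn_like` and `pool_stub` are hypothetical names, and the stub returns zeros of the right shape instead of doing real pooling.

```python
import numpy as np


def map_fn_like(fn, elems):
    """Apply fn to matching slices of every array in elems along axis 0,
    then stack the per-slice results (a rough analogue of tf.map_fn)."""
    batch_size = elems[0].shape[0]
    assert all(e.shape[0] == batch_size for e in elems)
    results = [fn(tuple(e[b] for e in elems)) for b in range(batch_size)]
    return np.stack(results)


def pool_stub(pair, pooled_height=3, pooled_width=7):
    """Stand-in for _pool_rois: maps one image's (feature_map, rois) to an
    array of shape (n_rois, pooled_height, pooled_width, n_channels)."""
    feature_map, rois = pair
    n_rois, n_channels = rois.shape[0], feature_map.shape[-1]
    return np.zeros((n_rois, pooled_height, pooled_width, n_channels),
                    dtype=feature_map.dtype)


feature_maps = np.ones((4, 200, 100, 1), dtype=np.float32)
rois = np.zeros((4, 2, 4), dtype=np.float32)
print(map_fn_like(pool_stub, [feature_maps, rois]).shape)  # (4, 2, 3, 7, 1)
```

Each call receives a pair but returns a single array, so the output structure differs from the input structure; with tf.map_fn, that is exactly the situation in which the output type has to be declared.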

Let’s put everything together:

Let’s test our implementation! We take a single-channel feature map that is 200 pixels high and 100 pixels wide, and extract 2 RoIs from it using a 3x7 (height × width) pooling grid. Each RoI is given by 4 relative coordinates. The feature map is all 1s except for a single value of 50 placed at (height-1, width-3).

```python
import numpy as np
import tensorflow as tf  # TF 1.x API: placeholders and sessions

# Define parameters
batch_size = 1
img_height = 200
img_width = 100
n_channels = 1
n_rois = 2
pooled_height = 3
pooled_width = 7

# Create feature map input
feature_maps_shape = (batch_size, img_height, img_width, n_channels)
feature_maps_tf = tf.placeholder(tf.float32, shape=feature_maps_shape)
feature_maps_np = np.ones(feature_maps_shape, dtype='float32')
feature_maps_np[0, img_height-1, img_width-3, 0] = 50
print(f"feature_maps_np.shape = {feature_maps_np.shape}")

# Create ROI input
roiss_tf = tf.placeholder(tf.float32, shape=(batch_size, n_rois, 4))
roiss_np = np.asarray([[[0.5, 0.2, 0.7, 0.4], [0.0, 0.0, 1.0, 1.0]]], dtype='float32')
print(f"roiss_np.shape = {roiss_np.shape}")

# Create layer (ROIPoolingLayer is the class assembled above)
roi_layer = ROIPoolingLayer(pooled_height, pooled_width)
pooled_features = roi_layer([feature_maps_tf, roiss_tf])
print(f"output shape of layer call = {pooled_features.shape}")

# Run tensorflow session
with tf.Session() as session:
    result = session.run(pooled_features,
                         feed_dict={feature_maps_tf: feature_maps_np,
                                    roiss_tf: roiss_np})

print(f"result.shape = {result.shape}")
print(f"first roi embedding=\n{result[0,0,:,:,0]}")
print(f"second roi embedding=\n{result[0,1,:,:,0]}")
```

The lines above define a test input for the layer, build the corresponding tensors and run a TensorFlow session so we can check its output.

Running the script will result in the following output:

```
feature_maps_np.shape = (1, 200, 100, 1)
roiss_np.shape = (1, 2, 4)
output shape of layer call = (1, 2, 3, 7, 1)
result.shape = (1, 2, 3, 7, 1)
first roi embedding=
[[1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1.]]
second roi embedding=
[[ 1. 1. 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1. 1. 50.]]
```

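We can check that the tensor shapes match our expected results. The first embedding is all 1s, since the first RoI does not contain the pixel with value 50, while the second RoI covers the whole feature map and its bottom-right pooling area picks up the 50.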

It seems to be working!

Conclusion

And that is it for today folks!

Today we have learned what the ROI Pooling Layer does and how it can be used to implement attention. We have also learned how to extend Keras with custom layers that have no weights, and we have implemented the aforementioned ROI Pooling Layer.

I hope this was useful for you! If it was, do not forget to share the article and leave a comment.

Thank you to Ari Brill, Tjark Miener and Bryan Kim for feedback on the article.

References

Ross Girshick. Fast R-CNN. Proceedings of the IEEE International Conference on Computer Vision. 2015.

Kota Hara, Ming-Yu Liu, Oncel Tuzel, Amir-massoud Farahmand. Attentional Network for Visual Object Detection. 2017.
