Target Class Labels in the Prediction Result of the TensorFlow Estimator API in ML Engine

Johny Jose
Published in Analytics Vidhya · 3 min read · Feb 6, 2019

Introduction to the Problem

This is a very specific problem I ran into while configuring a serving model with the TensorFlow Estimator API in Cloud ML Engine. When the prediction input contains multiple lines and you need the list of target classes mapped to each output line, things get tricky. The predictions from the trained graph can be configured to include various outputs: for example, the predicted class or label name for a multiclass classifier, the list of probabilities for the classes, and so on. What I found tricky was adding the actual list of class or label names to the prediction result. This is useful when a deployed model is used by someone who does not know what the classes are, so that they can see which probability belongs to which class.

There seems to be no straightforward way to add the class names to the prediction output. We can get the probability of each class from the softmax layer, and with tf.gather we can obtain the predicted class name.

But what if you also want the full list of class names in the prediction result? After a lot of frantic searching through places like StackOverflow, I could not find anyone who had described this particular problem. So after a lot of trial and error, I came up with this solution (it may be a bit over-engineered).

Solution

The probabilities can be obtained from the softmax layer in the neural network using the tf.nn.softmax function on the logits.

probabilities = tf.nn.softmax(logits)
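To make this step concrete, here is a minimal plain-Python sketch of what a row-wise softmax computes. It is only an illustration: the helper name softmax_rows and the sample logits are made up for this example, and in the model itself tf.nn.softmax does the work inside the graph.

```python
import math


def softmax_rows(logits):
    """Row-wise softmax over a list of lists (illustrative stand-in for tf.nn.softmax)."""
    result = []
    for row in logits:
        # Subtracting the row maximum keeps exp() numerically stable
        # without changing the resulting probabilities.
        row_max = max(row)
        exponentials = [math.exp(value - row_max) for value in row]
        total = sum(exponentials)
        result.append([value / total for value in exponentials])
    return result


# Hypothetical logits for two input lines and two classes.
sample_logits = [[2.0, 1.0], [0.5, 1.5]]
sample_probabilities = softmax_rows(sample_logits)
```

Each row of sample_probabilities sums to 1, and the largest logit in a row receives the largest probability.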

Now, to obtain the target label or class name associated with the highest probability, we can use tf.gather with a given list of target names. The predicted index is obtained with tf.argmax, and tf.gather then looks up the corresponding name in the list of targets.

predicted_indices = tf.argmax(probabilities, 1)
predicted_class = tf.gather(TARGET_LABELS, predicted_indices)
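The same lookup can be traced in plain Python, which may help when checking what the two ops return. This is a sketch under assumptions: argmax_rows and gather are hypothetical helpers that mimic tf.argmax (along axis 1) and tf.gather on nested lists, and the labels and probabilities are the two-class example values used in this article.

```python
TARGET_LABELS = ['Class A', 'Class B']  # example labels from this article


def argmax_rows(probabilities):
    """Index of the largest value in each row (stand-in for tf.argmax(..., 1))."""
    return [row.index(max(row)) for row in probabilities]


def gather(params, indices):
    """Look up params at each index (stand-in for tf.gather with a 1-D index list)."""
    return [params[index] for index in indices]


probabilities = [[0.5, 0.4], [0.4, 0.5]]
predicted_indices = argmax_rows(probabilities)
predicted_class = gather(TARGET_LABELS, predicted_indices)

print(predicted_indices)  # [0, 1]
print(predicted_class)    # ['Class A', 'Class B']
```

The result has one entry per input line, which is why the predicted class maps cleanly onto each prediction object.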

Now comes the tricky part: getting the list of names and adding it to the prediction result. First I used tf.where with the condition that the value is less than 2. Since every probability is at most 1, the condition holds everywhere, so tf.where returns the coordinates of every element of the probabilities matrix. Each coordinate is a [row, column] pair, and the column is the class index. So I took the last column of these coordinates, which gives the class index of every element, and reshaped it to the shape of the probabilities. The reshaped matrix now holds the class index for every position.

If the target classes are:

['Class A', 'Class B']

And the predicted probabilities will be in the form of a 2x2 matrix if there are two inputs:

[[.5, .4], [.4, .5]]

Cloud ML Engine returns the outputs as a list of objects, one per line of prediction input. To map the class names onto this structure, a matrix of indices with the same shape as the probabilities is formed first, and tf.gather then converts it into a matrix of class names of the form:

[['Class A', 'Class B'], ['Class A', 'Class B']]

This allows the class names to be mapped correctly onto each object in the result list. The code that implements this is:

# the condition is True everywhere, since every probability is less than 2
condition = tf.less(probabilities, 2)
# the [row, column] coordinates of all the elements are obtained
indices = tf.where(condition)

# position of the last column in the result
last_index = indices.get_shape().as_list()[1] - 1
# keep only the last column (the class index) --> [[0], [1], [0], [1]]
last_indices_value = tf.slice(indices, [0, last_index], [-1, -1])

# reshape the result to the correct format
# [[0], [1], [0], [1]] --> [[0, 1], [0, 1]]
classes_shape = tf.reshape(last_indices_value, tf.shape(probabilities))

# form the classes list with the indices in the new shape
class_names = tf.gather(TARGET_LABELS, classes_shape)
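To see what each step produces, the same sequence can be traced in plain Python on the two-input example. This is only a sketch of the logic, not the graph code: nested lists stand in for tensors, the sliced column is kept as a flat list for brevity, and the output key names 'probabilities' and 'classes' are hypothetical.

```python
TARGET_LABELS = ['Class A', 'Class B']
probabilities = [[0.5, 0.4], [0.4, 0.5]]

# Like tf.where on an always-true condition: one [row, column]
# coordinate per element, in row-major order.
indices = [
    [row_index, column_index]
    for row_index, row in enumerate(probabilities)
    for column_index, value in enumerate(row)
    if value < 2
]

# Keep only the last coordinate (the class index) of every element.
last_indices_value = [coordinate[-1] for coordinate in indices]

# Reshape the flat index list back to the shape of the probabilities.
num_classes = len(probabilities[0])
classes_shape = [
    last_indices_value[start:start + num_classes]
    for start in range(0, len(last_indices_value), num_classes)
]

# Like tf.gather: replace every index with its label.
class_names = [[TARGET_LABELS[index] for index in row] for row in classes_shape]

# One object per input line, pairing each row of probabilities with its labels.
predictions = [
    {'probabilities': row, 'classes': names}
    for row, names in zip(probabilities, class_names)
]

print(class_names)  # [['Class A', 'Class B'], ['Class A', 'Class B']]
```

Because every row of the label matrix is the same list, repeating TARGET_LABELS once per input row (for instance with tf.tile) would presumably give the same result; that variant is a suggestion and has not been tried here.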

Conclusion

So that is the over-engineered way to include the list of class names in the prediction, and it works with any number of lines in the prediction input. I wrote this mainly so that I remember the method and how I arrived at it, and for anyone who happens to come across a similar problem.

Please do comment if you have an alternative or simpler way of doing the same. Thank you.
