TensorFlow: Combining Categorical and Continuous Variables

Illia Polosukhin
3 min read · Oct 28, 2016


I originally planned to go over some examples in TF.Learn for natural language problems, but work on TF.Learn itself has kept me busy.

In Part 3 of this tutorial, we reviewed how to add categorical variables to your model.

But the other day, @philbort filed a bug (https://github.com/ilblackdragon/tf_examples/issues/7) asking how to combine continuous and categorical variables in one model.

After trying to do it quickly in response, I realized it’s quite hard. We are going to fix some of that in upcoming changes to TensorFlow, but in the meantime I want to show a way to put together various pieces of TF.Learn to achieve this goal, and also talk about some new concepts.

Let’s start with the input function: a function you can pass to fit / predict as an alternative to the x and y data arrays. The idea is that you build a piece of the graph that reads and samples your data, instead of keeping it all in memory. For example, if you have a CSV file, you can write an input function like this:

import tensorflow as tf

def my_input_fn():
    # Read lines from my.csv and batch them into a string tensor of shape [32].
    examples = tf.contrib.learn.read_batch_examples(
        'my.csv', batch_size=32, reader=tf.TextLineReader)
    header = ['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex',
              'Age', 'SibSp', 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked']
    record_defaults = [[1], [1], [1], [''], [''], [1], [1], [1], [''], [1.0], [''], ['']]
    # decode_csv parses each line into one tensor per column.
    cols = tf.decode_csv(examples, record_defaults=record_defaults)
    features = dict(zip(header, cols))
    target = features.pop('Survived')
    return features, target

We use read_batch_examples to set up a reader (TextLineReader) that reads lines from my.csv and batches them into a string tensor of shape [32]. Then we call decode_csv, which parses each string in that tensor into a list of column tensors. We define the number and dtypes of these tensors by providing record_defaults. Finally, we return features (a dict mapping column names to tensors) and the target tensor.

An additional thing to know about input functions: depending on the flags passed to read_batch_examples, they may return data infinitely (for as long as we keep asking for it) or for a specific number of epochs. They also randomize order by default; to disable that, pass randomize_input=False.
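For example, a minimal variation (assuming the read_batch_examples flags of TF as of this writing) that reads my.csv for exactly one epoch, in file order:

examples = tf.contrib.learn.read_batch_examples(
    'my.csv', batch_size=32, reader=tf.TextLineReader,
    randomize_input=False, num_epochs=1)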

Now, as you can see, features will contain both continuous features like Fare and Age as well as categorical features like Sex and Pclass. There are different ways to handle this, for example using tf.contrib.lookup.HashTable to build a lookup table right in the graph.
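A minimal sketch of that in-graph approach, assuming a known vocabulary for Sex (the keys and the default value here are hypothetical, not from the post’s code):

# Map string values to integer ids inside the graph; unseen strings get -1.
sex_table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(
        keys=['male', 'female'], values=[0, 1]),
    default_value=-1)
sex_ids = sex_table.lookup(features['Sex'])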

Here I want to talk about an alternative path: doing everything with Pandas and then passing already preprocessed data into the model. This is less scalable (e.g. it won’t work very well in a distributed environment), but it works for local training.

Currently (2016/10/27) there is a limitation in what the model receives if x is a DataFrame: for legacy reasons, it gets translated into a single matrix.

To work around it, we will write an input function that feeds the preprocessed DataFrame into the model in the correct format. Then we will write a model that uses the categorical variables, already mapped to indices, together with the continuous ones.
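The preprocessing itself might look like this; a sketch with assumed column choices and file path, using Pandas category codes to map strings to indices:

import pandas as pd

df = pd.read_csv('train.csv')
categorical = ['Sex', 'Embarked']
continuous = ['Age', 'SibSp', 'Parch', 'Fare']
for col in categorical:
    # Map each categorical column to integer codes (missing values become -1).
    df[col] = pd.Categorical(df[col]).codes
df[continuous] = df[continuous].fillna(0).astype('float32')
X, y = df[categorical + continuous], df['Survived']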

Here we write pandas_input_fn, which uses learn.dataframe.queues.feeding_functions.enqueue_data to feed the DataFrame into the model (i.e. it adds nodes to the graph that are fed with data in parallel, from a separate thread). This should also be faster than passing x, y into fit, because it doesn’t block the training loop to fetch new records.
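A sketch of such a function, assuming the enqueue_data signature of this era of TF (the batch size and queue capacity here are arbitrary choices):

from tensorflow.contrib import learn

def pandas_input_fn(X, y=None, batch_size=32):
    def input_fn():
        df = X.copy()
        if y is not None:
            df['target'] = y
        # enqueue_data sets up a queue plus a thread that feeds DataFrame
        # rows into it while the training loop runs.
        queue = learn.dataframe.queues.feeding_functions.enqueue_data(
            df, 1000, shuffle=True)
        # A dequeued batch is [row_index, col1, col2, ...].
        batch = queue.dequeue_many(batch_size)
        features = dict(zip(['index'] + list(df.columns), batch))
        if y is not None:
            target = features.pop('target')
            return features, target
        return features
    return input_fn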

Now, in our model function, we process the features: continuous ones are cast to float and reshaped into [batch_size, 1]. Categorical features are each embedded using their own embedding matrix (see Part 2 for more details about embeddings). Then all these features are concatenated into one feature vector and passed into a deep 3-layer neural network. The latter part is the same as in the previous models.
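A minimal sketch of that model function; the feature lists, vocabulary sizes, embedding size and layer sizes below are assumptions, and the helpers come from tf.contrib.layers (note the old tf.concat(dim, values) argument order of this era):

from tensorflow.contrib import layers

CATEGORICAL = {'Sex': 2, 'Embarked': 4}  # feature -> assumed vocabulary size
CONTINUOUS = ['Age', 'SibSp', 'Parch', 'Fare']
EMBEDDING_SIZE = 3

def my_model(features, target):
    parts = []
    for name in CONTINUOUS:
        # Cast each continuous column to float and shape it [batch_size, 1].
        parts.append(tf.reshape(tf.to_float(features[name]), [-1, 1]))
    for name, size in CATEGORICAL.items():
        # Each categorical column gets its own embedding matrix.
        embeddings = tf.get_variable(name + '_embeddings',
                                     [size, EMBEDDING_SIZE])
        parts.append(tf.nn.embedding_lookup(
            embeddings, tf.to_int64(features[name])))
    # One feature vector per example, fed into a 3-layer network.
    net = tf.concat(1, parts)
    for units in [10, 20, 10]:
        net = layers.fully_connected(net, units)
    logits = layers.fully_connected(net, 2, activation_fn=None)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits, tf.to_int64(target)))
    train_op = layers.optimize_loss(
        loss, tf.contrib.framework.get_global_step(),
        learning_rate=0.01, optimizer='Adam')
    return tf.argmax(logits, 1), loss, train_op

Putting the pieces together with an Estimator would then look something like:

classifier = learn.Estimator(model_fn=my_model)
classifier.fit(input_fn=pandas_input_fn(X, y), steps=1000)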

The final results are better than with either just categorical or just continuous variables; after a bit of training we get:

Accuracy: 0.7821

As always, you can find all the code on GitHub: https://github.com/ilblackdragon/tf_examples. Feel free to create an issue or file a pull request!

See the Text Classification post for how to leverage this to solve real problems.

Since writing this post, I founded NEAR Protocol. Read more about our journey.
