Keras Hyperparameter Tuning using Sklearn Pipelines & Grid Search with Cross Validation

Amine Benatmane
3 min read · Aug 16, 2019


Training a deep neural network that generalizes well to new data is a very challenging problem. Furthermore, deep learning models are full of hyperparameters, and finding the optimal ones can be a tedious process!

Fortunately, Sklearn Grid Search is here to save us!

Keras Wrappers for the Scikit-Learn API

To perform grid search with Sequential Keras models (single-input only), you must turn these models into sklearn-compatible estimators by using the Keras wrappers for the Scikit-Learn API (refer to the docs for more details).

An sklearn estimator is a class object with fit(X, y), predict(X), and score methods (and optionally a predict_proba method).

No need to do that from scratch: you can use Sequential Keras models as part of your Scikit-Learn workflow through one of the two wrappers from the keras.wrappers.scikit_learn package:

  • KerasClassifier(build_fn=None, **sk_params): which implements the Scikit-Learn classifier interface.
  • KerasRegressor(build_fn=None, **sk_params): which implements the Scikit-Learn regressor interface.

Arguments

  • build_fn: callable function or class instance that should construct, compile, and return a Keras model, which will then be used to fit/predict.
  • sk_params: model parameters & fitting parameters.

Note that, like all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that you can create the estimator without passing any values to sk_params.

Example: Text classification with IMDB movie reviews dataset

In this example, we show how to combine a Sklearn Pipeline, GridSearchCV, and these Keras wrappers to fine-tune some of the hyperparameters of TfidfVectorizer and of a basic Sequential Keras model on the IMDB movie reviews dataset from Kaggle.

Link: https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews

You can create a Kaggle Kernel directly from this dataset :)

Or fork my kernel on Kaggle: https://www.kaggle.com/med92amine/keras-hyperparameter-tuning

Import some useful libraries
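The original import cell isn't embedded in this copy of the post; a minimal sketch covering everything used below, assuming Keras 2.x (which ships the scikit_learn wrapper module):

```python
import re
import string

import pandas as pd

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
```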

Loading Data using pandas

Loading the IMDB dataset
Dataset shape and class distribution
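The loading gists aren't embedded in this copy; a minimal sketch, assuming the default Kaggle input path and the dataset's review/sentiment columns:

```python
# Path as it appears inside a Kaggle Kernel created from this dataset.
df = pd.read_csv('../input/imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv')

print(df.shape)                         # (50000, 2)
print(df['sentiment'].value_counts())   # 25000 positive / 25000 negative
```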

Perfect! This dataset is class-balanced!

Do Some Text Preprocessing

Remove punctuation, numbers, extra whitespace, URLs, … You know, the usual cleaning!

Function to remove punctuation
Function to remove URLs
Function to remove HTML tags
Clean text
Convert targets to int
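The helper gists aren't embedded in this copy; a minimal sketch of the five steps above, built on plain re/string utilities (the original kernel's exact implementation may differ):

```python
def remove_punctuation(text):
    # Strip every ASCII punctuation character.
    return text.translate(str.maketrans('', '', string.punctuation))

def remove_urls(text):
    return re.sub(r'https?://\S+|www\.\S+', '', text)

def remove_html_tags(text):
    # Raw IMDB reviews contain markup such as <br />.
    return re.sub(r'<.*?>', '', text)

# Clean text: lowercase, then strip HTML, URLs, and punctuation.
df['review'] = (df['review']
                .str.lower()
                .apply(remove_html_tags)
                .apply(remove_urls)
                .apply(remove_punctuation))

# Convert targets to int: positive -> 1, negative -> 0.
df['sentiment'] = (df['sentiment'] == 'positive').astype(int)

X, y = df['review'].values, df['sentiment'].values
```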

Implement Keras Model creator function

We want to fine-tune these hyperparameters: optimizer, dropout_rate, kernel_init method and dense_layer_size.

These parameters must be defined in the signature of the create_model() function with default values. You can add other hyperparameters if you want, such as learning_rate, ...

binary_crossentropy is perfect for a two-class classification problem.

Implementing Keras Model creator function
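A minimal sketch of such a function; the layer sizes and the fixed input_dim are illustrative choices (input_dim must match the TF-IDF vocabulary size set in the pipeline below):

```python
def create_model(optimizer='adam', dropout_rate=0.2,
                 kernel_init='glorot_uniform', dense_layer_size=64,
                 input_dim=5000):
    # Every tunable hyperparameter appears in the signature with a default.
    model = Sequential()
    model.add(Dense(dense_layer_size, input_dim=input_dim,
                    kernel_initializer=kernel_init, activation='relu'))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1, kernel_initializer=kernel_init, activation='sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])
    return model
```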

Create sklearn-like estimator

It’s a classification problem, so we use the KerasClassifier wrapper.

Creating Keras Classifier
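Wrapping is a one-liner; any keyword set later through the grid is forwarded either to create_model or to model.fit:

```python
clf = KerasClassifier(build_fn=create_model, verbose=0)
```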

Tuning some TF-IDF Hyperparameters

We need to convert the text into numerical feature vectors to perform text classification. We will be using Sklearn TfidfVectorizer in this example.

We could use the Keras text preprocessing Tokenizer for that, but we want to fine-tune some TF-IDF hyperparameters as well. So we create a Sklearn Pipeline:

Sklearn Pipeline
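A sketch of the pipeline. TfidfVectorizer returns a scipy sparse matrix while Keras Dense layers expect dense arrays, so a densifying step is inserted between them (an assumption about the original kernel, which may have handled this differently; max_features is pinned so it matches the model's input_dim):

```python
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer(max_features=5000)),
    # Keras needs dense input; convert the sparse TF-IDF matrix.
    ('to_dense', FunctionTransformer(lambda x: x.toarray(), accept_sparse=True)),
    ('clf', clf),
])
```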

Defining the Hyperparameter Space

Here we define our hyperparameter space, including the Keras fit hyperparameters epochs and batch_size:

Hyperparameter Space
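The exact grid from the original kernel isn't reproduced here; an illustrative one follows, using the step__param naming that Pipeline requires. Keep it small: the number of fits grows multiplicatively with each list.

```python
param_grid = {
    # TfidfVectorizer hyperparameters (prefix = pipeline step name).
    'tfidf__ngram_range': [(1, 1), (1, 2)],
    'tfidf__use_idf': [True, False],
    # Keras model hyperparameters (forwarded to create_model).
    'clf__optimizer': ['adam', 'rmsprop'],
    'clf__dropout_rate': [0.2, 0.5],
    'clf__kernel_init': ['glorot_uniform', 'he_uniform'],
    'clf__dense_layer_size': [32, 64],
    # Keras fit hyperparameters (forwarded to model.fit).
    'clf__epochs': [2, 3],
    'clf__batch_size': [64, 128],
}
```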

And Finally Performing Grid Search with KFold Cross Validation

It’s the same as grid search with any other sklearn estimator; no big deal!

Remember: for K-fold cross-validation, K is not a hyperparameter. The purpose of cross-validation is not to come up with a final “performant” model, but to see how well the model generalizes to unseen data against a relevant performance metric.

Instantiate Grid Search
Fit Grid Search
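A sketch of both steps, with an explicit KFold splitter (the fold count and scoring are illustrative choices):

```python
grid = GridSearchCV(estimator=pipeline,
                    param_grid=param_grid,
                    cv=KFold(n_splits=3, shuffle=True, random_state=42),
                    scoring='accuracy',
                    verbose=2)

grid.fit(X, y)
```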

And voilà!

Performing Grid Search
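Once the search finishes, the best combination and its cross-validated score are available on the fitted object:

```python
print('Best score: %.4f' % grid.best_score_)
print('Best params:', grid.best_params_)
```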

Bonus: Tune the number of dense layers and their sizes

You can find an excellent example with mnist dataset on keras github repo: https://github.com/keras-team/keras/blob/master/examples/mnist_sklearn_wrapper.py
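Following that example's idea, a dense_layer_sizes list parameter lets the grid vary depth and width together; a sketch (not the repo's exact code):

```python
def create_model_deep(dense_layer_sizes=(64,), optimizer='adam',
                      dropout_rate=0.2, input_dim=5000):
    model = Sequential()
    model.add(Dense(dense_layer_sizes[0], input_dim=input_dim, activation='relu'))
    model.add(Dropout(dropout_rate))
    # One extra Dense/Dropout pair per additional entry in the tuple.
    for layer_size in dense_layer_sizes[1:]:
        model.add(Dense(layer_size, activation='relu'))
        model.add(Dropout(dropout_rate))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer=optimizer,
                  metrics=['accuracy'])
    return model

# The layer configurations then become just another grid entry, e.g.:
# 'clf__dense_layer_sizes': [(64,), (64, 64), (128, 64, 32)]
```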
