Using large NLP models in production
--
Bidirectional Encoder Representations from Transformers (BERT) is a Transformer-based machine learning technique for natural language processing (NLP) pre-training developed by Google [1]. BERT and other NLP models trained on large datasets consume a lot of memory. For example, BERT-base requires around 450 MB of memory, whereas BERT-large requires around 1.2 GB.
When we fine-tune any of these NLP models on our own dataset, we end up saving multiple versions of the fine-tuned model, and in production we may need to load any of those versions at runtime. As a general practice we might consider loading the NLP model every time a web request is made, but that is not feasible in a low-compute production environment: downloading a BERT model from MongoDB to local disk and then loading the local copy into memory is time-consuming. So in this post I would like to describe my approach to dealing with large NLP models in production.
For demonstration purposes I will be using a simple Scikit-learn model.
Dependencies
scikit-learn
joblib
PyMongo
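All three are available from PyPI (the gridfs and bson modules used below ship with the PyMongo package):
pip install scikit-learn joblib pymongo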
Let’s create a simple sklearn model [2]:
import joblib
from sklearn import svm
from sklearn import datasets

# train a simple SVM classifier on the Iris dataset
clf = svm.SVC()
X, y = datasets.load_iris(return_X_y=True)
clf.fit(X, y)

# serialise the trained model to a local file
model_name = 'mymodel_v1'
model_fpath = f'{model_name}.joblib'
joblib.dump(clf, model_fpath)
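As a quick sanity check, the dumped file can be loaded back and used for prediction:
# reload the serialised model and predict on a few rows
clf_loaded = joblib.load(model_fpath)
print(clf_loaded.predict(X[:3]))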
Now the model is saved locally. Next we need to save it to the DB so that we can load it from the DB in production. For demonstration purposes I will be using MongoDB, with GridFS to store the (potentially large) model file.
# internal
import datetime
# external
import gridfs
import pymongo

# create mongo client to communicate with mongoDB
mc = pymongo.MongoClient(host='220.24.52.190',
                         port=27017)
# load or create database
mydb = mc.test_database
# load / create file system collection
fs = gridfs.GridFS(mydb)
# load / create model status collection
mycol = mydb['ModelStatus']

# save the local file to mongodb
with open(model_fpath, 'rb') as infile:
    file_id = fs.put(
        infile.read(),
        model_name=model_name
    )

# insert the model status info to ModelStatus collection
params = {
    'model_name': model_name,
    'file_id': file_id,
    'inserted_time': datetime.datetime.now()
}
result = mycol.insert_one(params)
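We can verify that both writes succeeded:
# confirm the file landed in GridFS and the status document was inserted
print(fs.exists(file_id))    # True
print(result.inserted_id)    # ObjectId of the ModelStatus document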
Now the model is saved in MongoDB and can be retrieved in production. When retrieving the model from the DB we will follow the Singleton design pattern, implemented with a metaclass, so that each model version is downloaded at most once. The following is the code of the base class:
class ModelSingleton(type):
    """
    Metaclass that creates a Singleton base type when called.
    Instances are cached per mongo_id, so each model version
    is constructed (and downloaded) at most once per process.
    """
    _mongo_id = {}

    def __call__(cls, *args, **kwargs):
        mongo_id = kwargs['mongo_id']
        if mongo_id not in cls._mongo_id:
            print('Adding model into ModelSingleton')
            cls._mongo_id[mongo_id] = super(ModelSingleton, cls)\
                .__call__(*args, **kwargs)
        return cls._mongo_id[mongo_id]
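One caveat of this pattern: superseded versions stay cached in _mongo_id for the lifetime of the process. If memory is tight, the cache can be cleared explicitly; the helper below is a hypothetical addition of mine, not part of the original design:
# hypothetical helper: drop all cached model instances
def clear_model_cache():
    ModelSingleton._mongo_id.clear()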
The code to load the model is as follows:
from bson.objectid import ObjectId

class LoadModel(metaclass=ModelSingleton):
    def __init__(self, *args, **kwargs):
        self.mongo_id = kwargs['mongo_id']
        self.clf = self.load_model()

    def load_model(self):
        print('loading model')
        # fetch the GridFS file by its _id, write a local copy,
        # then deserialise it with joblib
        f = fs.find({"_id": ObjectId(self.mongo_id)}).next()
        with open(f'{f.model_name}.joblib', 'wb') as outfile:
            outfile.write(f.read())
        return joblib.load(f'{f.model_name}.joblib')
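Since joblib.load also accepts a file-like object, the round trip through local disk can be skipped if you don't need the on-disk copy; a small variant under that assumption:
import io

# load the model straight from the GridFS stream, no temp file
def load_model_from_stream(mongo_id):
    f = fs.find({"_id": ObjectId(mongo_id)}).next()
    return joblib.load(io.BytesIO(f.read()))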
Now we only need to watch the GridFS file id to detect a change in the model version. The code to fetch the id of the latest version from the DB is as follows:
# fetch the newest ModelStatus document for this model;
# file_id points at the corresponding GridFS file
result = mycol.find_one(
    {'model_name': model_name},
    sort=[('inserted_time', pymongo.DESCENDING)]
)
if result:
    mongo_id = str(result['file_id'])
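Since this lookup runs on every request, it is worth adding a compound index so the query stays cheap (my suggestion, not strictly required):
# one-time setup: index the fields used by the version lookup
mycol.create_index([('model_name', pymongo.ASCENDING),
                    ('inserted_time', pymongo.DESCENDING)])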
The code to load the model into production is as follows:
model = LoadModel(mongo_id=mongo_id)
clf = model.clf
Now the model will be downloaded from the DB only when there is a new version in the DB; otherwise the instance already in memory is reused.
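To see the caching at work, call LoadModel twice with the same id: the second call returns the cached instance without touching the DB.
model_a = LoadModel(mongo_id=mongo_id)
model_b = LoadModel(mongo_id=mongo_id)
assert model_a is model_b  # same cached object, no second download
Below is a minimal sketch of how the whole flow could sit inside a web endpoint. Flask and the /predict route are my assumptions for illustration, not part of the original setup:
from flask import Flask, jsonify, request

app = Flask(__name__)  # assumption: Flask as the serving framework

@app.route('/predict', methods=['POST'])
def predict():
    # cheap indexed query for the latest version on every request
    doc = mycol.find_one({'model_name': 'mymodel_v1'},
                         sort=[('inserted_time', pymongo.DESCENDING)])
    # LoadModel only downloads when file_id has changed
    model = LoadModel(mongo_id=str(doc['file_id']))
    features = request.get_json()['features']
    return jsonify(prediction=model.clf.predict([features]).tolist())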
Happy coding !!!
References:
[1] https://en.wikipedia.org/wiki/BERT_(language_model)
[2] https://scikit-learn.org/stable/modules/model_persistence.html