Graphs and ML: Remembering Models

A Developer’s Log

Lauren Shin
Neo4j Developer Blog
6 min read · Jul 11, 2018

In my first article, I wrote about the linear regression procedures I added to Neo4j. Today I want to explain some of their internals, and why I chose to build them the way I did.

User-defined procedures must “remember” information between calls in order to build and maintain a machine learning model. This goes beyond the typical functionality of procedures in Neo4j. In this article, I explore the key implementation details of a set of functions and procedures that create, train, test, use, and store linear models of graph data. I hope it helps you understand my work or implement similar user-defined procedures of your own.

To see these user-defined functions and procedures in action, read my previous post, in which I use linear regression to create a price predictor for short-term rentals in Austin, TX.

I take a brief writing break to illustrate the awesomeness of user-defined procedures in Neo4j.

The Goal

I want to perform linear regression on the data I have stored in Neo4j. Many libraries provide tools for creating linear models (pandas and NumPy for Python, Commons Math for Java, etc.), but to use them directly I must export data from my graph into another piece of software. To model my data from within Neo4j, I need to extend the functionality of Cypher, the graph query language.

The preferred way to extend Cypher is with user-defined functions and procedures. These are written in Java, built (for instance, with Maven), deployed to the Neo4j database as JAR files, and called from Cypher. So I need to write a procedure that performs linear regression on graph data.

What makes this problem interesting?

First, let’s take a look at a typical procedure in Neo4j: apoc.meta.graph. Running CALL apoc.meta.graph() returns a visual representation of the graph’s schema (your data model). For example, here is the result of this procedure call on the short-term rental graph from my previous post:

This procedure reads the information stored in the graph to determine its underlying structure. Other procedures alter the graph. One example is apoc.refactor.mergeNodes, which merges multiple nodes into one. However, procedures do not typically remember information between calls; they just produce a stream of outputs or modify data.

I can create a procedure that, like apoc.meta.graph, accesses the information in the graph in order to create a linear model without storing any external information. I can pass in all data at once, perform the least squares calculations in the procedure, and return the parameters of the linear model. But if I make another call to the procedure, it will have already forgotten the model it just created.
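For contrast, the stateless approach can be sketched in plain Java. This is a minimal illustration and not code from my procedures. The class BatchLeastSquares and its fit method are hypothetical names, and the data arrives as two arrays instead of from a Cypher query. The whole data set must be passed in on every call, and nothing survives once the method returns.

```java
// Hypothetical stateless helper: fits y = intercept + slope * x over the
// whole data set in one call and keeps nothing afterwards.
final class BatchLeastSquares {
    private BatchLeastSquares() {}

    /** Returns {intercept, slope} of the ordinary least-squares line. */
    static double[] fit(double[] xs, double[] ys) {
        if (xs.length != ys.length || xs.length < 2) {
            throw new IllegalArgumentException("need at least two paired points");
        }
        int n = xs.length;
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += xs[i];
            meanY += ys[i];
        }
        meanX /= n;
        meanY /= n;
        // Centered cross-product sum and centered sum of squares of x.
        double sxy = 0, sxx = 0;
        for (int i = 0; i < n; i++) {
            double dx = xs[i] - meanX;
            sxy += dx * (ys[i] - meanY);
            sxx += dx * dx;
        }
        if (sxx == 0) {
            throw new IllegalArgumentException("x values must not all be equal");
        }
        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;
        return new double[] {intercept, slope};
    }
}
```

For example, fitting the points (1, 3), (2, 5), (3, 7), (4, 9) gives an intercept of 1 and a slope of 2. Adding a fifth point means calling fit again with all five.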

But what if I then decide to add more data to the model? What if I want to use a large amount of training data that requires too much memory to be input as the argument to one procedure?

Idea #1: Serialize!

My first attempt at a solution is a bit of a workaround. Instead of “remembering” information, the procedure stores the model in the graph so it can be accessed and updated later. The idea is to serialize the model’s Java object and store the resulting byte array in the graph between procedure calls. Here’s a visual representation of serialization and deserialization:

(Image source: http://blog.acorel.nl/2017/09/using-serializable-objects-in-abap.html)

Note: I am using SimpleRegression from the Apache Commons Math library. SimpleRegression updates its calculations with each incoming data point, so no individual data points are saved. Instead, it stores summary information, such as the mean “y” value and the total number of data points, and updates these values with each new point. As a result, each additional data point improves the model without increasing its memory usage. This also means that the byte array produced by serializing a SimpleRegression object stays small (at least, its size doesn’t scale with the size of the data set!).
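The sketch below makes the idea of an updating model concrete. It is a regression that keeps only a count and running sums, in the spirit of SimpleRegression but not its actual implementation. The class RunningRegression and its methods are hypothetical. Raw sums are used for brevity. They can lose precision on large or badly scaled data, so a production version would need a more careful update. Every field is a fixed-size primitive, so the object’s memory use and serialized size are the same after four points as after four million.

```java
// Hypothetical running-sums regression (y = intercept + slope * x).
// It stores a count and four sums, never the data points themselves.
final class RunningRegression implements java.io.Serializable {
    private static final long serialVersionUID = 1L;

    private long n;
    private double sumX, sumY, sumXX, sumXY;

    /** Folds one observation into the running sums. */
    void addData(double x, double y) {
        n++;
        sumX += x;
        sumY += y;
        sumXX += x * x;
        sumXY += x * y;
    }

    /** Undoes addData for a point that was added earlier. */
    void removeData(double x, double y) {
        n--;
        sumX -= x;
        sumY -= y;
        sumXX -= x * x;
        sumXY -= x * y;
    }

    long getN() {
        return n;
    }

    /** Least-squares slope; only meaningful once the x values vary. */
    double getSlope() {
        double sxx = sumXX - sumX * sumX / n; // centered sum of squares of x
        double sxy = sumXY - sumX * sumY / n; // centered cross-product sum
        return sxy / sxx;
    }

    double getIntercept() {
        return (sumY - getSlope() * sumX) / n;
    }

    double predict(double x) {
        return getIntercept() + getSlope() * x;
    }
}
```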

I first wrote the following helper functions so that throughout my project I could convert the SimpleRegression object to byte[] and vice versa. These required imports from java.io.*.

import java.io.*;

// Serializes the object into a byte array for storage
static byte[] convertToBytes(Object object) throws IOException {
    try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
         ObjectOutput out = new ObjectOutputStream(bos)) {
        out.writeObject(object);
        out.flush(); // push any buffered bytes into bos before copying them
        return bos.toByteArray();
    }
}

// Deserializes the byte array and returns the stored object
static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
    try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
         ObjectInput in = new ObjectInputStream(bis)) {
        return in.readObject();
    }
}

Then, the next time I wanted to edit the model, I retrieved the byte array from the graph and deserialized it.

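try {
    // Find the node holding this model's serialized bytes
    ResourceIterator<Entity> n = db.execute("MATCH (n:LinReg {ID:$ID}) RETURN n", parameters).columnAs("n");
    modelNode = n.next();
    byte[] model = (byte[]) modelNode.getProperty("serializedModel");
    // Rebuild the SimpleRegression object from its byte representation
    R = (SimpleRegression) convertFromBytes(model);
} catch (Exception e) {
    // Keep the original exception as the cause to make failures easier to debug
    throw new RuntimeException("no existing model for specified independent and dependent variables and model ID", e);
}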

After editing the model, I stored the new byte[] representation back in the same node.

try {
    // Serialize the updated model and overwrite the old bytes on the same node
    byte[] byteModel = convertToBytes(R);
    modelNode.setProperty("serializedModel", byteModel);
} catch (IOException e) {
    throw new RuntimeException("something went wrong, model can't be serialized so new model not stored", e);
}
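The whole round trip can be simulated without a database. In this sketch, a plain Map stands in for the LinReg node’s properties, and SumsModel is a hypothetical minimal serializable model standing in for SimpleRegression. The names SerializedModelDemo, load, store, and addPoint are also illustrative and do not come from my repository. Note that every addPoint call pays for one full deserialization and one full serialization, even though it adds only a single point.

```java
import java.io.*;
import java.util.Map;

// Hypothetical simulation of Idea #1 without Neo4j: a Map stands in for the
// LinReg node's properties, and SumsModel stands in for SimpleRegression.
final class SerializedModelDemo {
    /** Minimal serializable model that keeps running sums. */
    static final class SumsModel implements Serializable {
        private static final long serialVersionUID = 1L;
        long n;
        double sumX, sumY, sumXX, sumXY;

        void addData(double x, double y) {
            n++; sumX += x; sumY += y; sumXX += x * x; sumXY += x * y;
        }

        double getSlope() {
            return (sumXY - sumX * sumY / n) / (sumXX - sumX * sumX / n);
        }
    }

    static byte[] convertToBytes(Object object) throws IOException {
        try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
             ObjectOutputStream out = new ObjectOutputStream(bos)) {
            out.writeObject(object);
            out.flush();
            return bos.toByteArray();
        }
    }

    static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return in.readObject();
        }
    }

    /** Reads and deserializes the model stored on the "node". */
    static SumsModel load(Map<String, Object> node) {
        try {
            return (SumsModel) convertFromBytes((byte[]) node.get("serializedModel"));
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("no readable model stored on this node", e);
        }
    }

    /** Serializes the model and overwrites the bytes on the "node". */
    static void store(Map<String, Object> node, SumsModel model) {
        try {
            node.put("serializedModel", convertToBytes(model));
        } catch (IOException e) {
            throw new RuntimeException("model could not be serialized, so it was not stored", e);
        }
    }

    /** One "procedure call": deserialize, add a single point, serialize again. */
    static void addPoint(Map<String, Object> node, double x, double y) {
        SumsModel model = load(node);
        model.addData(x, y);
        store(node, model);
    }
}
```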

If you’re interested, check out the full code. Note that these preliminary implementations are very different from (and more convoluted than) the final version of my linear regression procedures described below.

Issues

What if, for the sake of design, I want separate procedures to create a model, add data, and remove data? Graph databases constantly receive updates, so I need a model that is as flexible as the graph. But serialization and deserialization take significant time, and with separate procedures the model would be deserialized and re-serialized on every call. You can pass in multiple data points at once to limit the number of procedure calls (and the number of times the model is stored and retrieved). Still, I need a better way to store the intermediate model between procedure calls so that it can be updated as many times as I need.

Idea #2: Static Map

Static variables in the Java classes implementing the procedures live as long as the database keeps running. Therefore, we can store model objects by name in a static map. Because the models live in the procedures’ static memory, each step of linear regression (create, add data, remove data, etc.) can be isolated in a separate procedure that alters the same SimpleRegression model. A procedure such as add can be called once for each data point without severe performance penalties. This gives us the simplified design we want. With every step isolated, the procedures are clear and the user has greater control over each step of building the linear model.

The models are stored in a static ConcurrentHashMap in LRModel.java, one of the Java classes that implement the procedures. Whenever a procedure needs to access a model, it retrieves it from models by name using the method from. This map implementation allows safe concurrent access from multiple threads.

private static final ConcurrentHashMap<String, LRModel> models = new ConcurrentHashMap<>();

// Looks up a model by name and fails loudly if no such model was created
static LRModel from(String name) {
    LRModel model = models.get(name);
    if (model != null) return model;
    throw new IllegalArgumentException("No valid LR-Model " + name);
}
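The pattern can be sketched end to end with a hypothetical ModelRegistry class. In it, create, addData, removeData, and delete each stand in for a separate procedure, and all of them find the model by name through from. One caveat is worth spelling out: ConcurrentHashMap makes the lookup thread-safe, but not the model object it returns. For that reason, the sketch’s minimal Model (a running-sums stand-in, not SimpleRegression) synchronizes its own methods.

```java
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch of the static-map design. Each static method plays
// the role of one procedure; none of them (de)serializes anything.
final class ModelRegistry {
    /** Minimal running-sums stand-in for SimpleRegression. */
    static final class Model {
        private long n;
        private double sumX, sumY, sumXX, sumXY;

        // The map does not guard the model itself, so updates are synchronized.
        synchronized void add(double x, double y) {
            n++; sumX += x; sumY += y; sumXX += x * x; sumXY += x * y;
        }

        synchronized void remove(double x, double y) {
            n--; sumX -= x; sumY -= y; sumXX -= x * x; sumXY -= x * y;
        }

        synchronized long count() {
            return n;
        }

        synchronized double slope() {
            return (sumXY - sumX * sumY / n) / (sumXX - sumX * sumX / n);
        }
    }

    // Static, so it lives as long as this class stays loaded.
    private static final ConcurrentHashMap<String, Model> models = new ConcurrentHashMap<>();

    /** "create" procedure: putIfAbsent makes creation atomic. */
    static void create(String name) {
        if (models.putIfAbsent(name, new Model()) != null) {
            throw new IllegalArgumentException("LR-Model " + name + " already exists");
        }
    }

    /** Mirrors LRModel.from: look up by name or fail. */
    static Model from(String name) {
        Model model = models.get(name);
        if (model != null) return model;
        throw new IllegalArgumentException("No valid LR-Model " + name);
    }

    /** "add data" procedure: cheap enough to call once per data point. */
    static void addData(String name, double x, double y) {
        from(name).add(x, y);
    }

    /** "remove data" procedure: reverses an earlier addData. */
    static void removeData(String name, double x, double y) {
        from(name).remove(x, y);
    }

    /** "delete" procedure: drops the model from memory. */
    static void delete(String name) {
        models.remove(name);
    }
}
```

With putIfAbsent, two concurrent create calls for the same name cannot both succeed.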

Now we only have to serialize and store the models before database shutdown, then load them back into the procedures’ static memory when the database restarts. Check out the full implementation on GitHub.
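Saving and restoring the whole map could look roughly like the following sketch. ModelPersistence, snapshot, and restore are hypothetical names. Where the bytes end up (a node property, a file) is left open, and this is not how my repository hooks into shutdown and startup. The snapshot copies the map into a HashMap first so the stored object has a plain type. If models are being updated while the snapshot runs, it may capture some of those updates and miss others.

```java
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical save/restore for the static model map. The bytes could be
// written to a node property or a file; that choice is left to the caller.
final class ModelPersistence {
    /** Minimal serializable stand-in for a stored model. */
    static final class Model implements Serializable {
        private static final long serialVersionUID = 1L;
        long n;
        double sumX, sumY, sumXX, sumXY;

        void add(double x, double y) {
            n++; sumX += x; sumY += y; sumXX += x * x; sumXY += x * y;
        }
    }

    static final ConcurrentHashMap<String, Model> models = new ConcurrentHashMap<>();

    /** Before shutdown: serialize every model into a single byte array. */
    static byte[] snapshot() {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
            // Copy into a plain HashMap so the stored object has a simple type.
            out.writeObject(new HashMap<>(models));
        } catch (IOException e) {
            throw new UncheckedIOException("could not serialize models", e);
        }
        return bos.toByteArray(); // the stream was flushed and closed above
    }

    /** After restart: load the saved models back into static memory. */
    @SuppressWarnings("unchecked")
    static void restore(byte[] bytes) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            models.putAll((Map<String, Model>) in.readObject());
        } catch (IOException e) {
            throw new UncheckedIOException("could not read saved models", e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("saved model class is missing", e);
        }
    }
}
```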

Limitations

  • If the database shuts down unexpectedly, the static variables are cleared and the model is lost. It might be a good idea to add a backup option that serializes and saves the model at regular intervals. After a database failure, restart the database and rebuild the model from the most recent backup.
  • Serialization is not the best way to save the model. If anything changes in the next version of Commons Math, the updated version may not recognize a SimpleRegression object serialized by an earlier version.
  • Stats for test data are not kept across a database shutdown and restart. Ideally, I would implement the updating simple regression myself instead of using Commons Math. I could then store all necessary information about the training and testing data instead of a serialized SimpleRegression.
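The periodic backup suggested in the first bullet could be scheduled with a standard ScheduledExecutorService. This is a sketch under assumptions. PeriodicBackup is a hypothetical class, and the caller supplies both the snapshot (for example, a method that serializes the static map) and the sink (for example, code that writes the bytes to a node or a file). The task catches RuntimeException because scheduleAtFixedRate suppresses all later runs once a task throws.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;

// Hypothetical periodic backup: every period, take a snapshot of the
// in-memory models and hand the bytes to a caller-supplied sink.
final class PeriodicBackup {
    private final ScheduledExecutorService scheduler =
            Executors.newSingleThreadScheduledExecutor(r -> {
                Thread t = new Thread(r, "lr-model-backup");
                t.setDaemon(true); // do not keep the JVM alive just for backups
                return t;
            });

    /** Runs one backup immediately, then one every period. */
    void start(Supplier<byte[]> snapshot, Consumer<byte[]> sink, long period, TimeUnit unit) {
        scheduler.scheduleAtFixedRate(() -> {
            try {
                sink.accept(snapshot.get());
            } catch (RuntimeException e) {
                // A thrown exception would cancel all later runs, so report and continue.
                System.err.println("model backup failed: " + e);
            }
        }, 0, period, unit);
    }

    /** Stops scheduling further backups. */
    void stop() {
        scheduler.shutdown();
    }
}
```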

Improve my work!
If you have ideas more stable than static maps and serialization, let me know or implement them yourself! I challenge you to improve my work. I would love to discuss possibilities; just message me on LinkedIn or @ML_auren. Cheers!
