Recently I gave a talk on the present and future of machine / deep learning (ML / DL) in the enterprise. In an enterprise setting, the questions and topics are more applied than at a research conference: for example, how do I get my team started in ML, and how best do we integrate ML with the existing systems we run? Afterwards, in the panel discussion, the topic of Java and machine learning came up.
Java is noticeably absent in the machine learning space. Virtually no ML frameworks are built in Java: there is DL4J, but I genuinely don’t know anyone who uses it; MXNet has a Scala API but not a Java one, and it’s not written in Java; and TensorFlow has an incomplete Java API. Yet Java has a massive footprint in the enterprise, representing perhaps trillions of dollars invested globally over the last 20 years in every conceivable domain: financial services, trading, e-commerce, telcos and more. For machine learning, the “first among equals” programming language is not Java, but Python. Personally, I enjoy coding in both Python and Java, but Frank Greco posed an interesting question that got me thinking:
What would Java need to compete with Python in ML? What if Java got serious about truly supporting machine learning?
Is it important?
Let’s motivate this discussion. Since 1998, Java has been at the top table as far as multiple enterprise evolutions and revolutions have been concerned: the web, mobile, browser vs native, messaging, i18n and l10n globalisation support, scaling out, and supporting every kind of enterprise information store you care to mention, from relational databases to Elasticsearch.
That level of unconditional support prompts a very healthy can-do, roll-up-your-sleeves-and-code-it kind of culture on Java teams. There is no magic component or API that cannot be augmented or replaced by a good Java team.
But this is not the case in machine learning. Instead, Java teams have two options:
- Re-train / co-train in Python.
- Use a vendor API to add a machine learning capability to your enterprise system.
Neither of these options is really that palatable. The first requires significant upfront time and investment plus ongoing maintenance costs, while the second risks vendor lock-in and vendor de-support, introduces third-party components (with a network hop price to pay) into what may be a performance-critical system, and requires you to share data outside your organisation’s boundaries, a no-go for some.
Most damaging of all, in my opinion, is the potential for cultural attrition. Teams can’t change code they don’t understand or can’t maintain, so there is a dereliction of duty and someone else picks up the heavy lifting. Java-only teams run the risk of being left behind in the next big wave coming for enterprise computing: the machine learning wave.
So it is important and desirable for the Java programming language and platform to have first-class machine learning support. Without it, Java runs the risk of being slowly replaced by languages that do support ML well over the next 5–10 years.
Why is Python so dominant in ML?
First, let’s consider why Python is the dominant language in machine learning and deep learning.
I suspect it all started from a fairly innocuous feature: slicing support for lists. This support is extensible: any Python class implementing the __getitem__ and __setitem__ methods can be sliced using this syntax. The snippet below shows how powerful and natural this Python feature is.
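As a minimal sketch of that extensibility (the Playlist class and its contents are hypothetical, invented purely for illustration), a user-defined container can opt in to slice syntax by handling slice objects in __getitem__ and __setitem__:

```python
class Playlist:
    """Hypothetical container that opts in to Python's slicing syntax."""

    def __init__(self, tracks):
        self._tracks = list(tracks)

    def __len__(self):
        return len(self._tracks)

    def __getitem__(self, key):
        # `key` is an int for playlist[2] and a slice object for playlist[1:3].
        if isinstance(key, slice):
            return Playlist(self._tracks[key])
        return self._tracks[key]

    def __setitem__(self, key, value):
        # Delegating to the underlying list handles both ints and slices.
        self._tracks[key] = value

    def __repr__(self):
        return f"Playlist({self._tracks!r})"


playlist = Playlist(["a", "b", "c", "d", "e"])

first_two = playlist[:2]      # slice -> a new Playlist holding "a", "b"
every_other = playlist[::2]   # slices accept a step: "a", "c", "e"
last = playlist[-1]           # negative indices count from the end: "e"
playlist[1:3] = ["x", "y"]    # slice assignment goes through __setitem__
```

The same two methods are what let third-party array types reuse the bracket syntax, which is why the feature scales from plain lists to multi-dimensional arrays.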
Of course, there’s more. Python code is terser than older Java code. Exceptions are supported but are unchecked, and developers can easily code throwaway Python scripts to try stuff out without getting sucked into the Java mindset of “everything is a class”. It’s easy to get cranking with Python.
But now the main factor, in my opinion: although the Python community made a dog’s dinner of maintaining cohesion between 2.7 and 3, they did a far better job of building a well-designed, fast numeric computing library, NumPy. NumPy is built around the ndarray, the N-dimensional array object. Directly from the docs: “NumPy’s main object is the homogeneous multidimensional array. It is a table of elements (usually numbers), all of the same type, indexed by a tuple of positive integers”. Everything in NumPy follows from getting your data into an ndarray and then performing operations on it. NumPy supports multiple types of indexing, broadcasting and vectorisation for speed, and in general allows developers to easily create and manipulate large arrays of numbers.
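To make those features concrete, here is a small sketch using NumPy (it assumes NumPy is installed; the array contents and variable names are arbitrary illustrations). It covers basic slicing, advanced indexing, broadcasting and a vectorised reduction:

```python
import numpy as np

# A 3x4 ndarray holding the integers 0..11.
a = np.arange(12).reshape(3, 4)

row = a[1]                   # basic indexing: the second row
block = a[:2, 1:3]           # basic slicing: rows 0-1, columns 1-2
evens = a[a % 2 == 0]        # advanced (boolean) indexing: the even values, flattened
picked = a[[0, 2], [1, 3]]   # advanced (integer-array) indexing: a[0, 1] and a[2, 3]

# Broadcasting: the scalar and the length-4 vector are stretched to shape (3, 4).
scaled = a * 10
shifted = a + np.array([100, 200, 300, 400])

# Vectorisation: one call replaces an explicit Python loop over the rows.
col_means = a.mean(axis=0)
```

None of these lines contains an explicit loop; the iteration happens inside NumPy's compiled code, which is where the speed comes from.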
Dealing with large multi-dimensional arrays of numbers is the beating heart of machine learning coding, and especially deep learning. Deep neural networks are lattices of nodes and edges modelled by numbers. Run-time operations when training a network or performing inference on it require fast matrix multiplication.
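As a rough illustration of why matrix multiplication dominates, the forward pass of a single fully connected layer can be sketched in NumPy as follows. The layer sizes, random weights and the choice of a ReLU activation are assumptions made for the example, not something taken from a particular framework:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

batch, n_in, n_out = 32, 8, 4        # illustrative sizes, chosen arbitrarily
x = rng.normal(size=(batch, n_in))   # one row per input example
W = rng.normal(size=(n_in, n_out))   # one weight per edge between the two layers
b = np.zeros(n_out)                  # one bias per output node

# Forward pass: a matrix multiplication, a broadcast add of the bias,
# and an element-wise ReLU non-linearity.
y = np.maximum(x @ W + b, 0.0)
```

Training repeats operations like this, plus their gradients, many times over, so the speed of the `@` step largely sets the speed of the whole framework.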
Drawbacks to Python
Python isn’t a perfect language. In particular, the most popular Python runtime, CPython, has a Global Interpreter Lock (GIL), so scaling CPU-bound work across threads is not straightforward. Moreover, Python DL frameworks like PyTorch and TensorFlow still hand off core methods to opaque implementations. For example, the cuDNN library from NVIDIA has a profound influence on the scope of the RNN / LSTM implementation in PyTorch. RNNs and LSTMs are very important DL techniques for business applications in particular, as they specialise in classifying and predicting over variable-length sequences, e.g. web clickstreams, text snippets, user events and more.
To be fair to Python, this opacity / restriction applies to almost any ML / DL framework that is not written in either C or C++. Why? Because in order to obtain the maximum performance for core, high-frequency operations like matrix multiplication, developers go as “close to the metal” as they can.
What does Java need to do to compete?
I propose that there are three primary additions to the Java platform which, if present, would prompt the sprouting of a healthy and thriving machine learning ecosystem in Java:
1. Add native indexing / slicing support in the core language to compete with Python’s ease of use and expressive power, probably centred around the existing ordered collection List<E> interface. This support would also need to admit overloading to support point #2.
2. Build a Tensor implementation, probably in the java.math package but also bridging across to the Collections API. This set of classes and interfaces would act as the equivalent of the ndarray, and would provide additional indexing support: specifically the three types of NumPy indexing (field access, basic slicing and advanced indexing) that are essential for coding ML.
3. Support broadcasting — scalars and Tensors of arbitrary (but compatible) dimensions.
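Point 3 hinges on what “compatible” means. As a reference for the behaviour any Java implementation would need to match, the sketch below restates NumPy's broadcasting rule in plain Python: shapes are compared from the trailing dimension backwards, and two dimensions are compatible when they are equal or one of them is 1. The function name broadcast_shape is hypothetical, chosen for this example:

```python
def broadcast_shape(a, b):
    """Return the broadcast result shape for shapes `a` and `b`.

    Raises ValueError when the shapes are not compatible.
    """
    result = []
    # Walk both shapes from the right; a missing dimension counts as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"incompatible shapes {a} and {b}")
        # A dimension of size 1 is stretched to match the other operand.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


scalar_with_matrix = broadcast_shape((), (3, 4))      # a scalar stretches to (3, 4)
column_with_row = broadcast_shape((3, 1), (1, 4))     # both stretch to (3, 4)
matrix_with_vector = broadcast_shape((3, 4), (4,))    # the vector repeats per row
```

A call such as broadcast_shape((3, 4), (3,)) raises ValueError, because the trailing dimensions 4 and 3 differ and neither is 1.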
If these three things were in place in the core Java language and runtime, it would open the road to building “NumJava”, the equivalent of NumPy. Project Panama could also be used to provide vectorised, low-level access to fast tensor operations running on CPUs, GPUs, TPUs and beyond to help Java ML be the fastest around.
I’m not suggesting these additions are trivial — far from it, but the potential upside to the Java platform is enormous.
The snippet below shows how our NumPy broadcasting and indexing example could look in NumJava with a Tensor class, with slicing syntax supported in the core language, and respecting the current restriction on operator overloading.
A vision and call to action
We all know ML is going to change the world of business, just as the relational database, the internet and mobile did. There’s plenty of hype around, but equally some compelling papers and results are emerging. This paper, for example, promises a future where our optimal database, web and application server configs will be learned and tuned in the background using machine learning. You won’t even have to deploy ML to see it in your systems; one of your vendors is sure to do it for you.
From the pragmatic starting point outlined in this post, we could have as many machine / deep learning frameworks written in Java and running on the JRE as we have web frameworks, persistence frameworks or XML parsers. Just imagine that! We could envisage Java frameworks with support for convolutional neural networks (CNNs) for leading-edge computer vision, recurrent neural network implementations like LSTM for sequential datasets (essential to business), and cutting-edge ML features such as auto-differentiation and more. These frameworks would then engender and power the next generation of enterprise-class systems, all running seamlessly beside existing Java systems and using the same tooling: IDEs, testing frameworks, continuous integration. Importantly, they would be written and maintained by the same people. If you’re a fan of Java, what’s not to like about that?