Clojure MXNet for musculoskeletal disease diagnosis

Back in May, we released a proof-of-concept deep learning webapp using Cortex, the native Clojure deep learning library.

Now that Carin Meier has built a Clojure package for MXNet, we have decided to jump in and use it on the MURA (musculoskeletal radiographs) dataset. MURA is one of the largest public radiographic datasets and was made available to the community by Stanford University.

When we started this project, the Clojure MXNet documentation was quite sparse. It is getting much better now (there are web docs, for example), but I’d still like to share some code and explanations.

The goal of this project (called Xenon) is to reproduce Stanford’s results by recreating the model described in their paper (subsection 3.1) using the Clojure MXNet API.

All the code is available on our Bitbucket. It was run in a container built from a previously published Docker image that includes all the dependencies needed to run MXNet on a CUDA-enabled GPU. You can also see it in action on our webapp.

Landing page

Preparing data

To get the MURA dataset you need to read and agree to the Stanford University School of Medicine MURA Dataset Research Use Agreement and register, after which you will receive a link to download the dataset.

MXNet has several data loading APIs, and I used ImageRecordIter. To use it, we need to create the required .rec (RecordIO) files from the downloaded dataset with the im2rec tool, which expects the data to be organized in a specific way. At this point I also wanted to split the provided valid set into test and val: val for validation during training and test for testing afterwards. So my desired file structure looked like this:

data
├── test
│   ├── abnormal
│   │   ├── 00000.png
│   │   ├── 00001.png
│   │   ├── …
│   │
│   └── normal
├── train
│   ├── abnormal
│   └── normal
└── val
    ├── abnormal
    └── normal

To do this, I created a prepare.clj file containing a few helper functions. It also has a final build-image-data function that puts the MURA images into the desired structure.
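prepare.clj itself isn’t shown here, so below is a minimal sketch of the kind of helpers such a step needs. It is not the actual build-image-data code. It assumes that MURA image paths look like .../XR_ELBOW/patient11186/study1_positive/image1.png, where the study folder suffix carries the label. The names label-of, patient-of, split-valid and target-paths are hypothetical. Splitting the valid set by alternating patients is just one possible choice.

```clojure
(ns sketch.prepare
  (:require [clojure.string :as string]))

;; Hypothetical helpers, NOT the project's prepare.clj.
;; Assumes MURA paths like ".../XR_ELBOW/patient11186/study1_positive/image1.png".

(defn label-of
  "Maps a MURA image path to the class folder name im2rec will see."
  [path]
  (if (string/includes? path "_positive") "abnormal" "normal"))

(defn patient-of
  "Extracts the patient folder name, e.g. \"patient11186\"."
  [path]
  (re-find #"patient\d+" path))

(defn split-valid
  "Assigns each image of the MURA valid set to \"val\" or \"test\".
  Splits by patient (alternating, in sorted order) so that one patient's
  images never end up in both sets."
  [paths]
  (let [patients (sort (distinct (map patient-of paths)))
        test-set (set (take-nth 2 patients))]
    (group-by #(if (test-set (patient-of %)) "test" "val") paths)))

(defn target-paths
  "Pairs each source path with a destination like data/val/abnormal/00000.png.
  The index runs over the whole split, so file names stay unique."
  [split paths]
  (map-indexed (fn [i p]
                 [p (format "data/%s/%s/%05d.png" split (label-of p) i)])
               paths))
```

The design choice here is to split by patient rather than by image. This keeps near-identical radiographs of the same patient from appearing in both val and test.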

Finally, we have to run the im2rec tool. Following this tutorial, I created a new rec folder. I then generated the .lst image list files by running im2rec with the --list flag on the data saved in the layout it expects:

python resources/im2rec.py --list --recursive rec/mura-train data/train
python resources/im2rec.py --list --recursive rec/mura-val data/val
python resources/im2rec.py --list --recursive rec/mura-test data/test
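Each line of a .lst file is tab-separated and holds three fields: an image index, the label (printed as a float) and the image path relative to the root folder. With --recursive, im2rec should number the class subfolders in sorted order, so abnormal would get label 0 and normal label 1. The plain-Clojure sketch below parses such a file to check the class balance of each split. parse-lst-line and class-counts are hypothetical helpers and are not part of the project.

```clojure
(ns sketch.lst
  (:require [clojure.string :as string]))

;; Hypothetical helpers, not part of the project code.
;; Assumes the im2rec .lst layout: index<TAB>label<TAB>relative-path,
;; with the label printed as a float (e.g. "1.000000").
(defn parse-lst-line
  [line]
  (let [[idx label path] (string/split line #"\t")]
    {:index (Long/parseLong idx)
     :label (long (Double/parseDouble label))
     :path  path}))

(defn class-counts
  "Counts images per numeric label in the text of a .lst file."
  [lst-content]
  (->> (string/split-lines lst-content)
       (remove string/blank?)
       (map parse-lst-line)
       (map :label)
       frequencies))
```

For example, (class-counts (slurp "rec/mura-train.lst")) would return a map from label to image count for the training split.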

With those in place, we can finally create the binary .rec files:

python resources/im2rec.py --num-thread 16 rec/mura-train data/train
python resources/im2rec.py --num-thread 16 rec/mura-val data/val
python resources/im2rec.py --num-thread 16 rec/mura-test data/test

This produces the three .rec files that our image-record-iter iterators read. I’ve put the iterators in a separate data-iter.clj file so we can reuse them for training and for testing afterwards. To avoid repeating code, each iterator’s properties are built by merging a few maps.
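Below is a minimal sketch of the map-merging idea. It is not the actual data-iter.clj. The option keys mirror ImageRecordIter’s parameters (path-imgrec, data-shape, batch-size, shuffle, rand-mirror), but the exact keys and the values shown are placeholders.

```clojure
;; Hypothetical sketch of the "merge some maps" idea, not the project's data-iter.clj.
;; Key names mirror ImageRecordIter parameters; values are placeholders.
(def base-iter-params
  {:data-shape [3 224 224]
   :batch-size 8
   :data-name  "data"
   :label-name "softmax_label"})

;; Split-specific settings
(def train-params {:path-imgrec "rec/mura-train.rec" :shuffle true :rand-mirror true})
(def val-params   {:path-imgrec "rec/mura-val.rec"   :shuffle false})
(def test-params  {:path-imgrec "rec/mura-test.rec"  :shuffle false})

(defn iter-params
  "Merges the shared defaults with any number of override maps.
  Later maps win, so split-specific settings override the defaults."
  [& overrides]
  (apply merge base-iter-params overrides))
```

Each merged map would then be passed to mx-io/image-record-iter, with one call per split.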

As a side note, since we are not actually modifying the images with im2rec, we could use just the .lst files instead of the .rec files.

Training

We won’t be training our network from scratch. Instead, following the paper, we use transfer learning with the same pre-trained 169-layer DenseNet. Fortunately, Clojure MXNet can load models saved with any other MXNet API. This means we can use the implementation by miraclewkf, which was trained with the Python API. For this we need to download two files and put them into the model folder:

densenet-169-symbol.json: a JSON file with the symbol, MXNet’s term for the network structure and the connections between layers.

densenet-169-0000.params: a binary file with the trained weights (parameters).
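The prefix and epoch used when loading a checkpoint determine which files MXNet looks for. It expects <prefix>-symbol.json and <prefix>-<epoch zero-padded to 4 digits>.params, which is why the downloaded files must keep these names. The tiny helper below illustrates the mapping. It is hypothetical and assumes a model/densenet-169 prefix.

```clojure
;; Hypothetical illustration, not project code: the file names MXNet's
;; checkpoint loading expects for a given prefix and epoch.
(defn checkpoint-files
  [prefix epoch]
  {:symbol (str prefix "-symbol.json")
   :params (format "%s-%04d.params" prefix epoch)})
```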

To load it in Clojure I used a simple helper function:

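;; Assumes (:require [org.apache.clojure-mxnet.module :as m]);
;; x-params is the project's own parameters namespace.
(defn get-model
  ([]
   (get-model x-params/models-prefix))
  ([model-prefix]
   ;; Loads <prefix>-symbol.json and <prefix>-0000.params
   (let [model (m/load-checkpoint {:prefix model-prefix :epoch 0})]
     {:msymbol    (m/symbol model)
      :arg-params (m/arg-params model)
      :aux-params (m/aux-params model)})))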

To use the network for binary prediction, we need to modify the symbol by removing the last layer and adding our own. We keep the layers up to flatten0 untouched. As the name suggests, flatten0 is the flattening layer before the fc1 fully connected layer. We replace the original fc1 with our own fc1, which has 2 hidden units, one for each of the 2 possible outputs. We finish with the standard softmax-output layer. We also drop the pre-trained fc1 weights from the arg params, since their shape no longer matches the new 2-unit layer.

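;; Assumes (:require [clojure.string :as string]
;;                   [org.apache.clojure-mxnet.symbol :as sym])
(defn get-fine-tune-model
  [{:keys [msymbol arg-params num-classes layer-name]
    :or {layer-name "flatten0"}}]
  (let [all-layers (sym/get-internals msymbol)
        net (sym/get all-layers (str layer-name "_output"))]
    ;; New head: fc1 with num-classes units, then softmax output
    {:net (as-> net data
            (sym/fully-connected "fc1" {:data data :num-hidden num-classes})
            (sym/softmax-output "softmax" {:data data}))
     ;; Drop the pre-trained fc1 weights, whose shape no longer matches
     :new-args (->> arg-params
                    (remove (fn [[k _]] (string/includes? k "fc1")))
                    (into {}))}))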

Alternatively, we could use a single output unit with a sigmoid activation function, but I didn’t find one in the API. I did use that approach in the weighted version (more about it later).
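The two options carry the same information. For two logits z0 and z1, softmax([z0 z1]) assigns class 1 the probability 1 / (1 + e^(z0 - z1)), which is exactly sigmoid(z1 - z0). The plain-Clojure check of that identity below is an illustration only and is not MXNet code.

```clojure
;; Illustration only (not MXNet code): a 2-class softmax equals a sigmoid
;; applied to the difference of the two logits.
(defn softmax
  "Numerically stable softmax: subtract the max logit before exponentiating."
  [zs]
  (let [m  (apply max zs)
        es (mapv #(Math/exp (- % m)) zs)
        s  (reduce + es)]
    (mapv #(/ % s) es)))

(defn sigmoid [z]
  (/ 1.0 (+ 1.0 (Math/exp (- z)))))
```

For example, (second (softmax [0.3 1.7])) and (sigmoid (- 1.7 0.3)) agree up to floating-point rounding.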

Now we can plug in our train and validation data iterators into the model.

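;; Assumes (:require [org.apache.clojure-mxnet.module :as m]
;;                   [org.apache.clojure-mxnet.io :as mx-io]);
;; x-data is the project's data-iter namespace.
(defn init-model
  [devs msymbol arg-params aux-params]
  (-> (m/module msymbol {:contexts devs})
      ;; Data and label shapes both come from the training iterator
      (m/bind {:data-shapes (mx-io/provide-data x-data/train-iter)
               :label-shapes (mx-io/provide-label x-data/train-iter)})
      ;; :allow-missing lets the new fc1 parameters be initialized from scratch
      (m/init-params {:arg-params arg-params
                      :aux-params aux-params
                      :allow-missing true})))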

Matching the Stanford ML Group’s model exactly would require a weighted cross entropy loss, which MXNet does not provide. I tried implementing it myself with some success. However, I couldn’t figure out how to pass the extra information about each image’s study type, which determines the weights, from the iterators to the model. For now, we use standard cross entropy. This isn’t ideal, because the data is quite unbalanced, especially for some study types.
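For reference, the MURA paper weights the binary cross entropy per study type T:

L = -w_T,1 · y · log p(Y=1|X) - w_T,0 · (1 - y) · log p(Y=0|X)

Here w_T,1 = |N_T| / (|A_T| + |N_T|) and w_T,0 = |A_T| / (|A_T| + |N_T|), where |A_T| and |N_T| are the numbers of abnormal and normal training images of that type. The sketch below computes the weights and the per-image loss in plain Clojure. It is not wired into MXNet, and the function names are hypothetical.

```clojure
;; Sketch of the MURA paper's weighted binary cross entropy in plain Clojure.
;; Hypothetical helpers, not connected to MXNet.
(defn class-weights
  "Weights for one study type given its abnormal and normal training counts:
  w1 = |N_T| / (|A_T| + |N_T|), w0 = |A_T| / (|A_T| + |N_T|)."
  [n-abnormal n-normal]
  (let [total (double (+ n-abnormal n-normal))]
    {:w1 (/ n-normal total)
     :w0 (/ n-abnormal total)}))

(defn weighted-bce
  "Loss for one image with label y (1 = abnormal) and predicted p = P(abnormal).
  p is clamped away from 0 and 1 so the logarithms stay finite."
  [{:keys [w0 w1]} y p]
  (let [p (-> p (max 1e-7) (min (- 1 1e-7)))]
    (- (+ (* w1 y (Math/log p))
          (* w0 (- 1 y) (Math/log (- 1 p)))))))
```

The rarer class of each study type receives the larger weight, so mistakes on it cost more.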

That being said, we train with the Adam optimizer using the same parameters as described in the paper, and that part caused no problems.

At the end, we also need to save the module. This writes the JSON symbol and the binary params so we can load them later for prediction or further use.

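;; Assumes (:require [org.apache.clojure-mxnet.module :as m]
;;                   [org.apache.clojure-mxnet.initializer :as init]
;;                   [org.apache.clojure-mxnet.optimizer :as optimizer]
;;                   [org.apache.clojure-mxnet.callback :as callback]);
;; x-data and x-params are project namespaces.
(defn fit
  [model]
  (m/fit model
         {:train-data x-data/train-iter
          :eval-data  x-data/val-iter
          :num-epoch  x-params/train-epochs
          :fit-params (m/fit-params
                       {:initializer (init/xavier {:rand-type "gaussian"
                                                   :factor-type "in"
                                                   :magnitude 2})
                        ;; Adam with the paper's beta1 and beta2
                        :optimizer (optimizer/adam {:learning-rate x-params/learning-rate
                                                    :beta1 0.9
                                                    :beta2 0.999})
                        :batch-size x-params/batch-size
                        ;; Log speed and accuracy every 10 batches
                        :batch-end-callback (callback/speedometer x-params/batch-size 10)})})
  ;; Writes the JSON symbol, the params and the optimizer states to disk
  (m/save-checkpoint model {:prefix x-params/saved-mod-prefix
                            :epoch 0
                            :save-opt-states true}))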

The first thing to do in MXNet is to choose the context, i.e. the device(s) where the computations take place: CPU(s) or GPU(s). The fine-tune! function below combines loading and training. It runs the actual training on the specified list of context devices.

(defn fine-tune!
  ([]
   (fine-tune! x-params/default-context))
  ([devs]
   (let [{:keys [msymbol arg-params aux-params] :as model} (get-model)
         {:keys [net new-args]} (get-fine-tune-model (merge model {:num-classes 2}))
         ;; Pass the pre-trained aux params (BatchNorm statistics) as aux params
         module (init-model devs net new-args aux-params)]
     (fit module))))

All of this is in train.clj.

Calling fine-tune! now starts the training process. You should be able to follow it in your terminal:

[16:04:30] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[16:04:32] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [10] Speed: 114.16 samples/sec Train-accuracy=0.829545
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [20] Speed: 121.07 samples/sec Train-accuracy=0.836310
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [30] Speed: 124.22 samples/sec Train-accuracy=0.840726
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [40] Speed: 120.12 samples/sec Train-accuracy=0.851372
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [50] Speed: 121.30 samples/sec Train-accuracy=0.854167
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [60] Speed: 109.07 samples/sec Train-accuracy=0.862193
WARN org.apache.mxnet.WarnIfNotDisposed: LEAK: [one-time warning] An instance of org.apache.mxnet.Symbol was not disposed. Set property mxnet.traceLeakedObjects to true to enable tracing
INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [70] Speed: 111.77 samples/sec Train-accuracy=0.863556

In the output above, the Speedometer callback prints progress every 10 batches, and you can see the training accuracy increasing.

After a few batches there is also a warning about a memory leak. Because of it, training eventually crashes when it runs out of memory. The only workaround I found was to resume training from the last checkpoint when that happens.
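Resuming requires the most recent saved epoch. MXNet names checkpoint parameter files <prefix>-<epoch as 4 digits>.params, so a helper like the one below can find the latest one in a directory. This helper is hypothetical and not part of the project. Its result would then be passed as :epoch to m/load-checkpoint. The approach assumes checkpoints are saved with increasing epoch numbers, for example once per epoch.

```clojure
(ns sketch.resume
  (:require [clojure.java.io :as io]))

;; Hypothetical helpers, not from the project. Assumes MXNet's checkpoint
;; naming "<prefix>-<4-digit epoch>.params".
(defn checkpoint-epochs
  "Sorted epochs of all checkpoints for prefix-name among file-names."
  [prefix-name file-names]
  (let [re (re-pattern (str (java.util.regex.Pattern/quote prefix-name)
                            "-(\\d{4})\\.params"))]
    (->> file-names
         (keep #(second (re-matches re %)))
         (map #(Long/parseLong %))
         sort)))

(defn latest-epoch
  "Newest checkpoint epoch in directory dir, or nil when there is none
  (including when dir does not exist)."
  [dir prefix-name]
  (last (checkpoint-epochs prefix-name
                           (map #(.getName ^java.io.File %)
                                (.listFiles (io/file dir))))))
```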

Predicting

For prediction, the paper uses an ensemble of 5 networks. I implemented this part with the core.matrix library, which I used to average the responses of the MXNet modules used for prediction.
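The project uses core.matrix for the averaging. As a dependency-free illustration, here is a hypothetical plain-Clojure version. In it, predictions holds one entry per module, and each entry holds one class-probability vector per image of the study. The abnormal-idx argument is an assumption and depends on which label im2rec assigned to the abnormal folder.

```clojure
;; Hypothetical plain-Clojure sketch (the project itself uses core.matrix).
(defn mean-vector
  "Element-wise mean of equally sized numeric vectors."
  [vs]
  (let [n (count vs)]
    (mapv #(/ % n) (apply mapv + vs))))

(defn study-probabilities
  "Averages class probabilities over every image and every module.
  predictions: [module][image] -> probability vector."
  [predictions]
  (mean-vector (apply concat predictions)))

(defn study-abnormal?
  "True when the averaged probability of the abnormal class exceeds 0.5."
  [predictions abnormal-idx]
  (> (nth (study-probabilities predictions) abnormal-idx) 0.5))
```

Averaging over all (module, image) pairs at once gives the same result as averaging per module first, as long as every module scores the same images.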

Prediction expects all the images of one study to be in a single folder, the same way they are stored in the original MURA directory.

This part is mostly pure Clojure. We initialize the module the same way we did for training, except that we don’t need to change anything now, since we are loading our own saved module symbol. Then we convert each image into an NDArray to pass it to our module(s). The final result is the average response over every image in the study and every module.
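Before building the NDArray, the pixels have to be laid out in the channel-first (channels, height, width) order that MXNet’s convolution layers use by default. Below is a minimal sketch using only java.awt. It is a hypothetical helper, and it omits resizing and mean/std normalization. The resulting flat vector could then be turned into an NDArray of shape [1 3 height width], for example with ndarray/array.

```clojure
(ns sketch.image
  (:import (java.awt.image BufferedImage)))

;; Hypothetical helper, not the project's code: flattens an image into
;; channel-first (C, H, W) order as floats (MXNet's default dtype is float32).
;; Resizing and mean/std normalization are intentionally left out.
(defn image->chw
  [^BufferedImage img]
  (let [w (.getWidth img)
        h (.getHeight img)
        ;; Row-major pixel order: y is the outer loop, x the inner one
        pixels (vec (for [y (range h) x (range w)] (.getRGB img (int x) (int y))))
        ;; Extract one 8-bit channel from the packed ARGB ints
        channel (fn [shift]
                  (map #(float (bit-and (bit-shift-right % shift) 0xFF)) pixels))]
    ;; Red plane, then green, then blue
    (vec (concat (channel 16) (channel 8) (channel 0)))))
```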

And that’s about it. I hope this effort helps people using Clojure MXNet API in their endeavors.

Why Xenon?

In case you are curious about the project’s name: at Magnet we name our projects after chemical elements (Hydrogen is coming next). Xenon gas was introduced in computed tomography imaging, as well as in single-photon emission computed tomography, because of its advantages over other techniques in use at the time (BMC).

Clojure MXNet: looking forward

Xenon has been a very exciting project to work on, as it was our first using Clojure MXNet (cheers to Carin Meier!). We will continue to build more solutions based on this technology. As always, we are happy to receive feedback, opinions and questions, so do not hesitate to let us know!