Modeling Patient Health Using Transformers

Ashwini Kelkar
apree health (Castlight) Engineering
6 min read · Apr 11, 2023

Internship project building a Transformer-based model to classify Castlight members into “at-risk” health categories.

Internship and Project Background

Last summer, I was a Data Science Intern at Castlight Health, where I had the opportunity to work on Castlight’s Genius Classification Engine. The Engine hosts a suite of predictive models that place Castlight members into “at-risk” segments: at risk of developing certain health conditions, such as diabetes, or of requiring medical procedures, such as cardiac surgery. Trained on a feature-rich dataset that gathers member data from a variety of sources, the Engine drives many use cases at Castlight, such as personalized content on a member’s home screen and email communication. Care Guides, Castlight’s team of experts who provide individualized one-on-one assistance to members, also leverage these predictions to match members with the best care. All of this serves the ultimate goal of closing gaps in care and preventing Castlight customers from developing those health conditions.

For my project, I drew inspiration from the way the Genius Engine is set up. Today, the Engine uses twenty-plus models, each responsible for predicting one segment. As the Engine grows, so does the operational overhead of creating and maintaining each individual model added to it. There is also an opportunity to experiment with new features that could potentially improve the models. Further, some existing features were identified that either require a constant refresh or are now redundant. My task was to design, build, and train a deep-learning model that addresses some of these issues. In this blog, I describe a few features of the proposed solution to this unique problem.

I was introduced to the Transformer, a neural network architecture presented in the 2017 paper “Attention Is All You Need” by Vaswani et al. A few noteworthy features of this architecture, as described in the paper, are listed below (with a minimal code sketch after the list):

  • Transformers make use of self-attention mechanisms. Self-attention allows the model to weigh different parts of the input sequence at each step, based on the task at hand, which lets it capture long-range dependencies and relationships between parts of the input sequence.
  • Transformers implement multi-head attention by applying self-attention several times in parallel. This lets the model look at multiple aspects of the input sequence simultaneously and adds to its ability to capture complex relationships in the input.
  • Transformers process all positions of an input sequence in parallel rather than step by step, which makes them well-suited to tasks that involve long input sequences.
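
To make these points concrete, here is a minimal sketch of a single Transformer encoder block in Keras. The layer sizes are illustrative assumptions, not the configuration of the production model:

```python
import tensorflow as tf

# Illustrative sizes only: embedding width 64, 4 attention heads, FFN width 128.
embed_dim, num_heads, ff_dim = 64, 4, 128

inputs = tf.keras.Input(shape=(None, embed_dim))  # (batch, seq_len, embed_dim)

# Multi-head self-attention: query, key, and value all come from the same sequence.
attn_out = tf.keras.layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=embed_dim // num_heads)(inputs, inputs)
x = tf.keras.layers.LayerNormalization()(inputs + attn_out)  # residual + norm

# Position-wise feed-forward network applied to every position independently.
ffn = tf.keras.Sequential([
    tf.keras.layers.Dense(ff_dim, activation="relu"),
    tf.keras.layers.Dense(embed_dim),
])
outputs = tf.keras.layers.LayerNormalization()(x + ffn(x))

encoder_block = tf.keras.Model(inputs, outputs)
```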

By design, Transformers can also be fine-tuned for specific tasks, such as the classification task we tackle here. Next, I briefly describe the features of the model and the training pipeline.

Multi-label classification

The Genius Engine is a collection of models, each responsible for predicting a segment. The goal is to learn from a dataset prepared using patient biographical data (such as age) and medical history to place a member in the appropriate at-risk segments. This gives rise to a multi-label classification problem: in typical classification problems the classes to predict are mutually exclusive, whereas here the model is trained to predict more than one non-exclusive label. In deep learning models, this is implemented by setting the number of nodes in the last (output) layer equal to the number of target classes, in this case the number of “at-risk” segments we wish our model to learn.
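
As a minimal sketch (with a hypothetical segment count and feature width), such an output layer uses one sigmoid unit per segment together with a binary cross-entropy loss, so each segment prediction is an independent probability:

```python
import tensorflow as tf

NUM_SEGMENTS = 20  # hypothetical count of "at-risk" segments

# Pooled features for one member (stand-in for the Transformer's output).
features = tf.keras.Input(shape=(64,))

# One sigmoid unit per segment: each output is an independent probability,
# so a member can belong to several segments at once.
segment_probs = tf.keras.layers.Dense(NUM_SEGMENTS, activation="sigmoid")(features)

model = tf.keras.Model(features, segment_probs)
# Binary cross-entropy treats each label independently, the usual choice
# for multi-label classification.
model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=[tf.keras.metrics.AUC(multi_label=True, num_labels=NUM_SEGMENTS)])
```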

Moreover, for the design of the network, I explored a branching architecture, also known as a multi-branch or multi-task architecture. Upon accepting the input, the network branches into multiple paths. Each path receives the same input, learns a specific task, and makes a prediction based on that input. The predictions of these paths are combined to produce the network’s output, and the loss is computed on this final output. In this health-segment modelling task, each condition was learned by a separate branch, and the branch predictions were combined into a vector that forms the model’s output. Designing the network this way has advantages, including improved accuracy and performance, and it gave me the flexibility to design each branch specifically for the condition being modelled.
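
A sketch of such a branching head in the Keras functional API is shown below. It assumes a shared encoder has already produced a pooled feature vector per member; the branch names and sizes are hypothetical, not the production configuration:

```python
import tensorflow as tf

shared_features = tf.keras.Input(shape=(64,), name="pooled_features")

def condition_branch(name, units):
    """One branch per health segment; each branch can have its own depth and width."""
    hidden = tf.keras.layers.Dense(units, activation="relu", name=f"{name}_hidden")(shared_features)
    return tf.keras.layers.Dense(1, activation="sigmoid", name=f"{name}_risk")(hidden)

branch_outputs = [
    condition_branch("diabetes", 32),
    condition_branch("hypertension", 32),
    condition_branch("cardiac_surgery", 16),
]

# The per-branch predictions are concatenated into one multi-label output
# vector, and the loss is computed on this combined output.
outputs = tf.keras.layers.Concatenate(name="segment_probabilities")(branch_outputs)

model = tf.keras.Model(shared_features, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy")
```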

Modeling with sequential inputs

Castlight receives patient data from a variety of sources. Claims-related information is contained in ICD codes, codes from the International Classification of Diseases, a clinical cataloguing system. An ICD code can be either a diagnostic code or a procedure code and consists of two parts. For example, consider the code E08.01: the first part, E08, represents Diabetes Mellitus, while the full code E08.01 represents Diabetes Mellitus with coma. Hence the granularity of information learned from the code itself can be controlled.
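
As a small illustration of controlling that granularity, a code can be truncated to its category before being fed to the model (the helper below is my own illustration, not part of Castlight’s pipeline):

```python
# ICD-10 codes are hierarchical: the characters before the dot give the
# category, and the full code adds specificity. Truncating the code is one
# simple way to control the granularity the model learns from.
def icd_category(code: str) -> str:
    """Return the category portion of an ICD-10 code, e.g. 'E08.01' -> 'E08'."""
    return code.split(".")[0]

print(icd_category("E08.01"))  # -> "E08" (the Diabetes Mellitus category)
```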

A sequence of ICD claim codes represents a patient’s medical history. This idea led to formulating the problem as a sequence classification task. The dataset constructed has the characteristics of sequential data (a sketch of how such variable-length sequences can be prepared follows the list):

  1. The order in which the codes appear in the sequence is important in understanding the data
  2. The length of a sequence varies from person to person depending on the patient’s medical history
  3. The codes in the sequence are related to each other, each in some way dependent on the previous code
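
The sketch below shows one way such variable-length code histories can be mapped to integer ids and padded for batching; the example histories, lookup layer, and maximum length are assumptions for illustration:

```python
import tensorflow as tf

# Hypothetical member histories: variable-length, time-ordered lists of ICD codes.
member_histories = [
    ["E11.9", "I10", "E11.65"],
    ["M54.5", "M54.2", "G89.29", "M54.5", "Z79.899"],
]

# Map ICD codes to integer ids; the empty-string mask token reserves id 0
# for padding, and "[UNK]" covers codes not seen during training.
code_lookup = tf.keras.layers.StringLookup(mask_token="", oov_token="[UNK]")
code_lookup.adapt(tf.constant([code for history in member_histories for code in history]))

# Ragged tensors preserve each member's true sequence length...
ragged_ids = code_lookup(tf.ragged.constant(member_histories))
# ...and are padded to a common length so histories can be batched together.
padded_ids = ragged_ids.to_tensor(default_value=0, shape=[None, 64])
```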

The other novel sequential feature constructed is a measure of each observed code’s recency: the number of days between a reference date and the service date (or the date on which the claim was submitted). This information is fused with the ICD code and passed into the Transformer’s embedding layer.
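
A minimal sketch of this fusion, assuming illustrative vocabulary, sequence-length, and embedding sizes, and choosing addition as the fusion operation (concatenation would be an equally reasonable choice):

```python
import tensorflow as tf

VOCAB_SIZE, MAX_LEN, EMBED_DIM = 2000, 64, 64  # illustrative sizes

code_ids = tf.keras.Input(shape=(MAX_LEN,), dtype="int32", name="icd_code_ids")
recency_days = tf.keras.Input(shape=(MAX_LEN, 1), dtype="float32", name="recency_days")

# Learned embedding for each ICD code in the history.
code_embedding = tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(code_ids)

# Project the scalar recency to the embedding width and add it to the code embedding.
recency_embedding = tf.keras.layers.Dense(EMBED_DIM)(recency_days)
fused_sequence = code_embedding + recency_embedding  # (batch, MAX_LEN, EMBED_DIM)
```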

Transformer Embeddings

Embeddings are a key component of Transformer models. In general, embeddings allow us to represent high-dimensional input data in a lower-dimensional space by capturing the most relevant features of the original input. In Transformers, each element of the input sequence is mapped to a fixed-size embedding vector, and the resulting sequence of embeddings is passed to the later layers of the model.

For our health classification model, the embedding is a learned representation of ICD codes in which related codes have similar representations. The embeddings in the Transformer architecture introduced in “Attention Is All You Need” are of two kinds: position embeddings and token embeddings. Position embeddings represent the position at which an ICD code appears in the sequence. Token embeddings are learned during training and capture the meaning of the code in the sequence. These embeddings are combined and passed to the deeper layers of the model.
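
A common Keras pattern for combining the two embedding types looks like the sketch below. Note that the original paper used fixed sinusoidal position encodings; this sketch uses the widely adopted learned-position variant, with illustrative sizes:

```python
import tensorflow as tf

class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    def __init__(self, max_len, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.pos_emb = tf.keras.layers.Embedding(max_len, embed_dim)

    def call(self, code_ids):
        # "Where the code appears in the history."
        positions = tf.range(start=0, limit=tf.shape(code_ids)[-1])
        # "What the code is" plus "where it appears", summed elementwise.
        return self.token_emb(code_ids) + self.pos_emb(positions)

embedding_layer = TokenAndPositionEmbedding(max_len=64, vocab_size=2000, embed_dim=64)
```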

TensorFlow’s TensorBoard provides an embedding projector for visualizing high-dimensional embeddings graphically. It was interesting to see how the ICD codes are represented in this embedded space. For example, codes commonly seen for patients with a specific condition lie close to each other, i.e., have similar representations in the learned embedding space. There is also an opportunity to reuse embeddings: once trained on as many codes as possible across health segments, they do not need to be retrained again and again, only periodically when codes are revised or new codes are introduced. The more data the model trains on, the richer the insights and interpretations we can derive from the embeddings.
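
The same kind of inspection can be done outside TensorBoard with a few lines of NumPy; the sketch below assumes the trained embedding matrix and its code vocabulary have been extracted from the model:

```python
import numpy as np

def nearest_codes(query_code, code_vocab, embedding_matrix, k=10):
    """Return the k codes whose embeddings are closest (by cosine similarity) to query_code."""
    idx = code_vocab.index(query_code)
    normed = embedding_matrix / np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
    similarities = normed @ normed[idx]              # cosine similarity to the query
    neighbours = np.argsort(-similarities)[1:k + 1]  # skip the query code itself
    return [(code_vocab[i], float(similarities[i])) for i in neighbours]
```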

Figure: Visualizing the 10 nearest codes to M54 (Dorsalgia) among 1763 codes

Results, Future Scope, and Conclusion

This single multi-label, Transformer-based model performed comparably to the current Genius Segment models on a subset of health segments chosen for this exercise, evaluated with metrics suited to this classification problem. Performance should continue to improve as more training data becomes available, and the deep learning model also offers multiple opportunities for hyper-parameter optimisation.

Deep Learning models can ingest a variety of features. With this architecture, there is scope for easily integrating different kinds of inputs, from complex sources such as EHR data to noisier signals such as fitness-tracker data. Further, there is certainly value in the insights obtained from the learned embeddings: analysing them can help us understand how the codes, and by extension the segments, are correlated.

I’m grateful to Castlight Health and my mentors Robert Stewart and Jeff Hendricks for giving me the opportunity to work on and learn from an exciting project such as this. As someone who is passionate about healthcare, I learnt a great deal about the industry and the challenges that exist within it during my internship.

References

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems 30.