2019 Machine Learning Research Internships at Curai

Anitha Kannan
Published in Curai Health Tech
14 min read · Dec 6, 2019

Curai’s mission is to provide the world’s best healthcare for every human being. We are building an augmented intelligence system to help scale physicians’ abilities as well as to lower users’ barrier to entry to care.

Our internship program brings together students from diverse and strong technical backgrounds. Paired with mentors, interns work on problems that are central to making deep inroads into our mission. The program is open throughout the year, and we offer three kinds of internships: engineering, research, and medical. This post covers the work of our research interns this summer. You can also read about the work of our engineering interns in this other blog post.

We had three awesome research interns this summer: Clara McCreery and Sam Shleifer from Stanford, and Viraj Prabhu from Georgia Tech. It’s worth noting that Viraj was a returning research intern! They made tremendous research progress in three main areas: domain-relevant embeddings for question similarity; formulating the decoder for medical dialogue as a classification task, trading flexibility for control; and open-set medical diagnosis. We are excited to share that the research from their internships has been peer-reviewed and presented at the NeurIPS 2019 Machine Learning for Health (ML4H) workshop. While you can read their full papers at the links above, what follows in this blog post is a short, edited version of each project, in the interns’ own words.

If you are interested in helping us with our mission, please apply for an internship or for any of our other positions by emailing your resume to jobs@curai.com.

1. For Finding Similar Medical Questions, Domain Matters

Clara McCreery, Stanford

TL;DR: The task of determining whether a given pair of medical questions is “similar” is closely related to the task of medical question answering. In fact, these two tasks may be more closely linked (in terms of low-level features) than question-similarity tasks in different domains are. We show that the task of classifying whether a statement adequately answers a medical question (medical question answering) works well as a transfer-learning source for the medical question-similarity task. Our work highlights the importance of choosing both the right pre-training task and the domain from which that pre-training data comes. Full paper

This summer I completed a 10-week research internship at Curai, a Palo Alto start-up whose goal is to make primary care affordable and scalable. Curai’s technology drives its chat-based platform, Curai Health, where doctors interact with patients. This internship gave me an opportunity to take the lead on an open-ended research question and address the challenges of working with real-world data. Whereas many academic research questions are formulated around datasets that already exist, companies often care about problems for which labeled and vetted data do not yet exist. Few start-ups prioritize contributing to the research community the way Curai does, so this internship was a rare chance to see how fast-paced start-up culture can coexist with, and even benefit from, forward-thinking research.

The Problem

During my internship, I tackled the following natural language processing (NLP) problem: How can we accurately detect pairs of semantically similar medical questions? For example, the pairs of questions below should be identified as similar and dissimilar, respectively, even though the first pair has few words in common and the second pair differs by only one word.

Examples of pairs of similar and different questions

There are several reasons why we care about medical question similarity:

  • Some patient-asked questions are factual and will always have the same answer. If these questions have been answered before, it is faster to have an algorithm surface the answer than it is for a doctor to answer the question
  • It is expensive and not a good use of doctors’ time to answer the same questions over and over if the questions do not require follow-up questions or in-person examination
  • Different doctors have different experiences and expertise, and such a system can aggregate knowledge, following the principle that many heads are better than one
  • Collecting information about common primary complaints over time can give scientists insights into the spread of disease, onset of epidemics, etc.

The problem of question similarity is of interest to many beyond the medical field. The question-answering forum Quora is one of many businesses that continually work to improve their algorithms for detecting and merging duplicate questions. In fact, a common general language understanding benchmark, GLUE, includes a task called Quora Question Pairs (QQP), in which a model must label question pairs as either similar or different. What is unique about this problem in the medical domain is that domain expertise is required to correctly label medical questions as similar or not, and most general-language models have no such expertise. Furthermore, because no large labeled datasets of similar medical questions exist, it is unclear how best to embed this knowledge into a machine learning model.

Transfer Learning

One common technique for overcoming a dearth of training data is transfer learning, in which a neural network is first trained on a large dataset for a task that is related to, but different from, the task of interest. The weights from that first network are then transferred, i.e., used to initialize the weights of a network trained on the actual task of interest. This often works well because related tasks rely on similar low-level features, and learning those low-level features requires a lot of training data.

Although it is expensive and infeasible to generate a medical-question-pairs dataset large enough to train a neural network from scratch (QQP uses 363,000 question pairs), it is more manageable to generate a small dataset (3,000 question pairs) for fine-tuning and testing our models. We started with 1,500 freely available patient-asked questions from HealthTap, to ensure that questions were representative of what real patients ask online. For each question, we then had healthcare professionals generate two question pairs: one positive and one negative. The positive pair used different language to ask a similar question, whereas the negative pair used similar language to ask a different question.

We then set out to find a pre-training dataset and task for transfer learning that would produce useful low-level features for medical question similarity. One obvious choice was the previously mentioned QQP dataset. Although its questions are not medical, they provide many examples of semantically similar and different phrasings: the same task in a different domain.

Additionally, although there are no large datasets of medical question pairs online, there are many medical question-answer pairs from websites such as HealthTap [1] [2]. Some of this data is labeled with the category to which the question belongs. We leveraged this data to generate three more pre-training tasks in the medical domain:

  1. Question-Answer Pairs (QA): Here, positives were questions and the true answer that a doctor provided online. Negatives were questions and a doctor-written answer to a different question
  2. Answer-Answer Pairs (AA): Positives were complete medical answers. Negatives were the first 2 sentences from one medical answer followed by the last several sentences of a different answer
  3. Question-Category Pairs (QC): Positives were medical questions and their correct categorical label. Negatives were medical questions and a random incorrect categorical label
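To make the three constructions concrete, here is a minimal Python sketch that turns (question, answer, category) records into labeled pairs. The record format, the helper names (`make_qa_pairs`, `make_aa_pairs`, `make_qc_pairs`), the naive period-based sentence splitter, and the choice to treat “everything after the first two sentences” as the answer’s ending are all illustrative assumptions, not the pipeline used in the paper.

```python
import random
from typing import List, Tuple

# Hypothetical record format: (question, doctor_answer, category).
Record = Tuple[str, str, str]
# Each example is (text_a, text_b, label) with 1 = positive and 0 = negative.
Pair = Tuple[str, str, int]


def sentences(text: str) -> List[str]:
    """Naive period-based splitter; a real pipeline would use a proper tokenizer."""
    return [s.strip() + "." for s in text.split(".") if s.strip()]


def other_index(i: int, n: int, rng: random.Random) -> int:
    """Pick a random index in range(n) that differs from i (requires n >= 2)."""
    j = rng.randrange(n - 1)
    return j if j < i else j + 1


def make_qa_pairs(records: List[Record], rng: random.Random) -> List[Pair]:
    pairs = []
    for i, (question, answer, _) in enumerate(records):
        pairs.append((question, answer, 1))  # the doctor's true answer
        j = other_index(i, len(records), rng)
        pairs.append((question, records[j][1], 0))  # an answer to a different question
    return pairs


def make_aa_pairs(records: List[Record], rng: random.Random) -> List[Pair]:
    pairs = []
    for i, (_, answer, _) in enumerate(records):
        sents = sentences(answer)
        head, tail = " ".join(sents[:2]), " ".join(sents[2:])
        pairs.append((head, tail, 1))  # an intact answer
        j = other_index(i, len(records), rng)
        foreign_tail = " ".join(sentences(records[j][1])[2:])
        pairs.append((head, foreign_tail, 0))  # first 2 sentences + another answer's ending
    return pairs


def make_qc_pairs(records: List[Record], rng: random.Random) -> List[Pair]:
    categories = sorted({c for _, _, c in records})
    pairs = []
    for question, _, category in records:
        pairs.append((question, category, 1))  # the correct category
        wrong = rng.choice([c for c in categories if c != category])
        pairs.append((question, wrong, 0))  # a random incorrect category
    return pairs


# Toy usage with two made-up records.
records = [
    ("Is it safe to take ibuprofen for a headache?",
     "Usually yes. Take it with food. Avoid it if you have ulcers.", "medications"),
    ("What causes high blood pressure?",
     "Many factors contribute. Diet matters. Genetics also play a role.", "cardiology"),
]
rng = random.Random(0)
qa, aa, qc = make_qa_pairs(records, rng), make_aa_pairs(records, rng), make_qc_pairs(records, rng)
```

Every record yields exactly one positive and one negative, so each pre-training set is balanced by construction.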

For each pre-training dataset we trained a model first on the pre-training task, and then fine-tuned it on our medical question pairs dataset. We performed these experiments using the BERT model (results below) and replicated them with XLNet.
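The two-stage recipe (pre-train, then fine-tune starting from the pre-trained weights) can be illustrated without BERT. The sketch below swaps in a deliberately tiny logistic-regression model over hypothetical “shared token” features; `pair_features`, `train`, and the toy examples are made up for illustration. The only thing it demonstrates is the weight transfer itself: fine-tuning starts from a copy of the pre-trained weights rather than from zero.

```python
import math
from typing import Dict, List, Tuple

Weights = Dict[str, float]


def pair_features(text_a: str, text_b: str) -> Dict[str, float]:
    """Toy features: a bias plus one indicator per token the two texts share."""
    shared = set(text_a.lower().split()) & set(text_b.lower().split())
    feats = {"bias": 1.0}
    for tok in shared:
        feats["shared:" + tok] = 1.0
    return feats


def predict(weights: Weights, feats: Dict[str, float]) -> float:
    """Probability that the pair is a positive (logistic function of a dot product)."""
    z = sum(weights.get(k, 0.0) * v for k, v in feats.items())
    return 1.0 / (1.0 + math.exp(-z))


def train(examples: List[Tuple[str, str, int]], init: Weights,
          epochs: int = 20, lr: float = 0.5) -> Weights:
    """Plain SGD for logistic regression, starting from `init` (which is not mutated)."""
    weights = dict(init)  # copying the starting weights is the "transfer" step
    for _ in range(epochs):
        for a, b, label in examples:
            feats = pair_features(a, b)
            err = label - predict(weights, feats)
            for k, v in feats.items():
                weights[k] = weights.get(k, 0.0) + lr * err * v
    return weights


pretrain_examples = [  # stand-in for the QA pre-training pairs
    ("what lowers blood pressure", "diet and exercise lower blood pressure", 1),
    ("what lowers blood pressure", "rest and fluids help a cold", 0),
]
finetune_examples = [  # stand-in for the medical question pairs
    ("how do i lower my blood pressure", "what reduces high blood pressure", 1),
    ("how do i lower my blood pressure", "how do i treat a sprained ankle", 0),
]
pretrained = train(pretrain_examples, init={})
finetuned = train(finetune_examples, init=pretrained)
```

With BERT or XLNet the same pattern applies to millions of parameters instead of a handful of feature weights, which is why the pre-training data can matter so much when the fine-tuning set is small.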

As shown in the plot (x-axis: number of training examples used in the final task, i.e., medical question similarity), we found that the best pre-training task was Question-Answer (QA) matching. This highlights the importance of both the pre-training task and the domain from which the pre-training data comes. For instance, even though the AA and QC data are medically relevant, pre-training on the generic QQP data outperforms both of them. However, pre-training on the QA data (also medically relevant) outperforms QQP. We also observe that this advantage is amplified for smaller final-task training sets. To understand the cause of this improvement, we looked at the errors made by each model. Our qualitative error analysis shows that some medical synonyms that caused false negatives in the Quora model (e.g., hypertension/high blood pressure and menstrual cycle/period) were learned by the models pre-trained on the medical datasets, which got these questions right. We present many such examples in our paper. Through our error analysis, we also gained insight into the types of mistakes that our best models make.

2. Classification as Decoder: Trading Flexibility for Control in Medical Dialogue

Sam Shleifer, Stanford

Over the summer, I worked with Manish Chablani on AI-assisted automation of medical conversations. One step in that direction is suggesting responses that doctors can use as-is or edit in their conversations with telemedicine users. Suggesting responses is a difficult task because the model must handle a wide range of primary care dialogue scenarios, including courtesy responses (e.g., “Feel better soon”), asking patients relevant questions about their symptoms to narrow down the space of possible diagnoses, and offering advice on managing symptoms. Here are some examples:

Some examples of suggested responses

We experimented with generative models based on ULMFiT as well as transformer-based architectures, including GPT-2. Generative models are trained to predict the next word and, at inference time, use beam search to generate a set of likely responses. Because they can generate a likely response in any conversational context, generative approaches are incredibly flexible. This flexibility comes at the cost of control: undesirable responses in the training data will be reproduced by the model at inference time, and longer generations often don’t make sense. For example, in one conversation about drug interactions, the generative system proposed “It is not advisable to take Sudafed with Sudafed” as the most likely response.
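For readers unfamiliar with beam search, here is a minimal sketch of how a generative model turns next-word probabilities into a small set of likely responses. The bigram table `TOY_LM` is a hypothetical stand-in for a trained ULMFiT or GPT-2 decoder, which would condition on the whole conversation rather than only on the previous word.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical toy "language model": P(next word | previous word).
TOY_LM: Dict[str, Dict[str, float]] = {
    "<s>": {"feel": 0.6, "please": 0.4},
    "feel": {"better": 0.9, "free": 0.1},
    "better": {"soon": 0.8, "</s>": 0.2},
    "soon": {"</s>": 1.0},
    "free": {"</s>": 1.0},
    "please": {"rest": 1.0},
    "rest": {"</s>": 1.0},
}


def beam_search(lm: Dict[str, Dict[str, float]], beam_width: int = 2,
                max_len: int = 5) -> List[Tuple[List[str], float]]:
    """Return finished sequences with their log-probabilities, best first."""
    beams: List[Tuple[List[str], float]] = [(["<s>"], 0.0)]
    finished: List[Tuple[List[str], float]] = []
    for _ in range(max_len):
        # Extend every live beam by every possible next word.
        candidates = []
        for tokens, score in beams:
            for word, p in lm.get(tokens[-1], {}).items():
                candidates.append((tokens + [word], score + math.log(p)))
        # Keep only the beam_width highest-scoring extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            (finished if tokens[-1] == "</s>" else beams).append((tokens, score))
        if not beams:
            break
    return sorted(finished, key=lambda c: c[1], reverse=True)


results = beam_search(TOY_LM)
```

Nothing in this loop checks whether a high-scoring sequence is sensible or appropriate; it only follows the learned probabilities, which is the loss of control described above.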

Our Approach: Classification as Decoder

To avoid ineffective and inaccurate generations, I trained a classifier to choose from a predefined list of full responses. The classifier is trained on (conversation context, response class) pairs, where each response class is a noisily labeled group of interchangeable responses associated with an “exemplar response”. At inference time, the system simply looks up the exemplar response associated with the predicted class ID in a dictionary.

Beyond avoiding bad generations, the discriminative approach has another key advantage: doctors can edit and improve these exemplar responses over time without retraining the model. For example, if we wanted to switch from recommending that users sleep 6–8 hours per night to recommending 7–9 hours, we could simply update the message associated with the output class, and the discriminative model would immediately give the new advice in the same conversational context.
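A minimal sketch of classification as decoder: the model only predicts a class ID, and the response text comes from an editable lookup table. The exemplar texts, class IDs, and the keyword-based `keyword_classifier` are hypothetical placeholders for a trained classifier and the curated response classes.

```python
from typing import Callable, Dict

# Hypothetical response classes: class ID -> exemplar response.
EXEMPLARS: Dict[int, str] = {
    0: "Feel better soon!",
    1: "How long have you had these symptoms?",
    2: "Most adults should aim for 6-8 hours of sleep per night.",
}


def keyword_classifier(context: str) -> int:
    """Placeholder for a trained model mapping a conversation context to a class ID."""
    text = context.lower()
    if "sleep" in text:
        return 2
    if "thank" in text:
        return 0
    return 1


def suggest_response(context: str, classify: Callable[[str], int],
                     exemplars: Dict[int, str]) -> str:
    """Classification as decoder: predict a class ID, then look up its exemplar text."""
    return exemplars[classify(context)]


context = "Patient: I feel tired all the time. How much sleep do I need?"
before = suggest_response(context, keyword_classifier, EXEMPLARS)
# Experts revise the advice; the classifier is untouched and still predicts class 2.
EXEMPLARS[2] = "Most adults should aim for 7-9 hours of sleep per night."
after = suggest_response(context, keyword_classifier, EXEMPLARS)
```

Because the model’s output space is a set of IDs rather than free text, every response it can ever suggest is one that experts have already reviewed.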

The main drawback, and the reason this approach is not widely used, is that it requires a high-quality (conversation context, response class) training dataset. Building one means clustering the responses observed in the data into interchangeable response “classes” that are numerous enough to cover reasonable responses across a wide range of conversational topics, while each class still has enough examples for the model to learn it.

I spent the first few weeks of the summer on this problem, building a weakly supervised pipeline that relies on techniques like locality-sensitive hashing (LSH) with pretrained sentence encoders, a supervised BERT-based sentence similarity model, agglomerative clustering, and three hours of manual labeling. The process is detailed at length in our paper, but it ultimately produces 187 distinct response classes that appear frequently in the conversation logs. Each response class has an “exemplar” response, which is used as the suggested response at inference time. We ensure that exemplars are all factual, sensible, and grammatical by allowing experts to edit them before or after training.
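The core of the grouping step can be sketched as follows. This is a simplified stand-in, not the paper’s pipeline: it assumes each response already has an embedding from some pretrained sentence encoder (toy 3-dimensional vectors here), uses random-hyperplane LSH so that only responses landing in the same bucket are compared, and links any colliding pair whose cosine similarity clears a hypothetical threshold. Taking connected components of those links amounts to single-linkage agglomerative clustering cut at that threshold, restricted to LSH candidate pairs; the BERT similarity model and the manual labeling pass are omitted.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity; assumes neither vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def lsh_signature(v: Vector, planes: List[List[float]]) -> Tuple[int, ...]:
    """Random-hyperplane LSH: one bit per plane, 1 if v is on its non-negative side."""
    return tuple(int(sum(a * b for a, b in zip(v, p)) >= 0) for p in planes)


def cluster(embeddings: List[Vector], n_planes: int = 4,
            threshold: float = 0.9, seed: int = 0) -> List[int]:
    """Return a cluster id for each response embedding."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    planes = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_planes)]

    # Bucket responses by LSH signature so only likely neighbours get compared.
    buckets: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    for i, v in enumerate(embeddings):
        buckets[lsh_signature(v, planes)].append(i)

    parent = list(range(len(embeddings)))  # union-find forest

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # Link colliding pairs that are similar enough; connected components become clusters.
    for members in buckets.values():
        for x in range(len(members)):
            for y in range(x + 1, len(members)):
                i, j = members[x], members[y]
                if cosine(embeddings[i], embeddings[j]) >= threshold:
                    parent[find(i)] = find(j)
    return [find(i) for i in range(len(embeddings))]


# Toy embeddings standing in for sentence-encoder outputs (hypothetical).
embeddings = [
    [1.0, 0.0, 0.0],  # e.g. "Feel better soon!"
    [1.0, 0.0, 0.0],  # e.g. "Feel better soon." (assumed identical embedding)
    [0.0, 1.0, 0.0],  # e.g. "How long have you had a fever?"
]
labels = cluster(embeddings)
```

A single signature can split near-duplicates that straddle a hyperplane; practical LSH setups usually hash into several independent tables and compare any pair that collides in at least one of them.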

Experiments

The response clustering problem is unsupervised and can only be evaluated as part of the end-to-end response suggestion problem, specifically by asking medical experts whether the responses produced by one system are better than those produced by another. Luckily, the results were very encouraging: for a random sample of 775 conversation-response pairs, only 12% of our discriminative approach’s responses were judged worse than the doctor’s response in the same conversational context, compared to 18% for the generative model.

The rest of the summer was spent implementing and comparing different architectures, including triplet loss, Hugging Face’s GPT-based model, HRED, and ULMFiT. At the beginning of the summer, we suspected that we would need a more traditional ranking approach, but through experimentation we discovered that models using triplet loss or a multiple-choice loss took longer to train, were slower at inference, and did not significantly improve classification accuracy, as shown below. For the final version, we settled on ULMFiT, since it performed as well as the larger transformer-based approach while offering much faster training and inference.

Classification Accuracy for 187 classes using either 4 or 8 turns of conversational history

Our paper discusses more experiments on different response class generation procedures, model architectures, and tradeoffs.

3. Open-Set Medical Diagnosis

Viraj Uday Prabhu, Georgia Tech

In the US, an increasing number of adults (about a third!)[3][4] use the Internet to diagnose medical concerns, and online symptom checkers are increasingly part of this process. These tools are powered by diagnosis models similar to clinical decision support systems: they walk patients through a series of questions about their symptoms and finally provide a diagnosis. Such tools are poised to transform patient-facing telehealth services, which could move from today’s rule-based nurse-hotline protocols to more accurate and scalable AI systems.

However, existing diagnosis models make a closed-set assumption, i.e., that the universe of diseases is limited to those the model has been trained to recognize. In practice, it is infeasible to obtain sufficient training data for every human condition (ICD-10 contains over 14,025 diagnosis codes and 2,000 disease families), so a diagnosis model, once deployed, is highly likely to encounter cases of previously unseen conditions. To address this, telehealth providers often restrict their coverage to a specific area of care. However, determining from symptoms whether a patient falls within that diagnostic scope often requires additional models or human expertise, which adds cost and is itself error-prone. Further, each misdiagnosis is a missed opportunity for better care, and may even be safety-critical.

Consider the example above, which illustrates a typical symptom-checking workflow. For the given list of user-reported findings, which include muscle rigidity, jaw pain, and muscle spasms, a closed-set diagnostic model predicts whiplash injury to be the most likely condition, followed by scarlet fever and diverticulitis. In this case, the underlying condition is actually tetanus, which the diagnosis model was never trained to diagnose. Such misdiagnosis is an artifact of the ‘forced-choice’ nature of closed-set models. Further, the model here is a deep neural network: while such models tend to achieve state-of-the-art performance on many tasks, they are often highly overconfident about their predictions, as evidenced by the high confidence (95.3%) assigned to the incorrect prediction in this case. Both of these limitations pose significant challenges to deploying such diagnostic models.

Prior work in machine learning has studied the open-set learning problem (and the related problem of learning with a reject option), where the goal is to design approaches that are aware of previously unseen classes and avoid misclassifying them. Inspired by this, we frame machine-learned diagnosis as an open-set classification problem, which we call Open Set Diagnosis, and study how well previously proposed approaches apply to diagnosis. In the example above, our goal is to design an open-set diagnosis model that predicts “Don’t know” for the given set of findings and recommends additional diagnostic evaluation (such as physical exams, lab tests, and imaging studies), rather than misdiagnosing.

Another critical challenge in building diagnosis models is access to data. Health data usually lives in hospital repositories and, for privacy reasons, often cannot be taken outside its source site to be pooled with other sources. This makes training diagnosis models (particularly data-inefficient neural networks) difficult. Further, data at different healthcare sites is often complementary: for example, hospitals on the US East Coast are likely to have far more patient encounters for hypothermia than those on the West Coast. To develop models with high accuracy and coverage, we need mechanisms to bridge models trained at separate sites. To this end, we introduce the task of Ensembled Open Set Diagnosis, where we ensemble models trained on data sources that cannot be shared, and evaluate their open-set diagnosis performance.

Concretely, we study the two tasks illustrated below. In Task 1 (Open Set Diagnosis), our goal is to learn a model that diagnoses a select set of diseases (call this Lselect) and rejects unseen conditions (call this Lunknown), i.e., declares “Don’t know”. We are optionally given access to case data for a subset of conditions Lextra that is disjoint from Lselect, which we may use to model unseen conditions. In Task 2 (Ensembled Open Set Diagnosis), the goal and evaluation setting are identical to Task 1; however, the training data is now distributed across multiple healthcare sites.

For our dataset, we simulate clinical cases using the simulation algorithm proposed in this prior work from Curai, and create splits for Lselect, Lextra, and Lunknown. We measure performance as the diagnostic accuracy at a given false positive rate (i.e., the rate of misdiagnosing an unknown condition as a known one). Our goal is to achieve high accuracy at a low false positive rate. We evaluate on a large and challenging test set of over 200k cases corresponding to diseases in Lselect and Lunknown.
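One plausible way to compute this metric is sketched below; treat it as an assumption about the protocol rather than the paper’s evaluation code. A case is accepted as a diagnosis when the model’s confidence is at or above a threshold; the false positive rate is the share of Lunknown cases that get accepted, and accuracy is the share of Lselect cases that are both accepted and correctly diagnosed. The function name and the confidences in the example are made up.

```python
from typing import List, Tuple


def accuracy_at_fpr(known: List[Tuple[float, bool]], unknown: List[float],
                    target_fpr: float) -> Tuple[float, float]:
    """Return (threshold, accuracy) at the lowest threshold whose FPR <= target_fpr.

    known:   (confidence, prediction_is_correct) for test cases from Lselect.
    unknown: confidence of the (necessarily wrong) top prediction for Lunknown cases.
    A case is accepted as a diagnosis when its confidence >= threshold.
    """
    best = (float("inf"), 0.0)  # an infinite threshold accepts nothing
    thresholds = sorted({s for s, _ in known} | set(unknown), reverse=True)
    for t in thresholds:
        # FPR: share of unknown-condition cases that still receive a diagnosis.
        fpr = sum(s >= t for s in unknown) / len(unknown)
        if fpr > target_fpr:
            break  # lowering t further can only raise the FPR
        # Accuracy: share of known-condition cases accepted AND diagnosed correctly.
        acc = sum(s >= t and ok for s, ok in known) / len(known)
        best = (t, acc)
    return best


# Made-up confidences for illustration only.
known = [(0.95, True), (0.90, True), (0.80, False), (0.60, True)]
unknown = [0.85, 0.50, 0.40, 0.30]
strict = accuracy_at_fpr(known, unknown, 0.0)     # no unknown case may slip through
lenient = accuracy_at_fpr(known, unknown, 0.25)   # one in four unknown cases may slip through
```

Sweeping `target_fpr` from 0 to 1 traces out an accuracy-versus-FPR curve; a better open-set model keeps accuracy high at the low-FPR end.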

We experiment with three different training approaches: (i) a simple cross-entropy (CE) baseline, where we threshold the confidence of the final softmax layer to predict “Don’t know”; (ii) training with an additional background (BG) class as a catch-all for unseen conditions, with the extra class trained on examples from Lextra; and (iii) a state-of-the-art open-set learning method called the Entropic Open-Set (EOS) loss, which trains with regular cross-entropy on seen (Lselect) data points and encourages high entropy otherwise (for Lextra). For training ensembles in Task 2, we experiment with both a naive ensembling strategy (predicting the class with the highest confidence) and a learned strategy, where we train an additional network to combine the outputs of individual expert models trained at different sites. All our models are parameterized as 2-layer neural networks, and we represent cases as one-hot encodings over a global symptom vocabulary.
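The three training signals and the naive ensembling rule can be sketched on raw logits as follows. This is a minimal illustration, assuming a softmax classifier over the Lselect conditions (plus one extra background output for BG); the function names and example logits are hypothetical. The EOS loss follows its usual definition: ordinary cross-entropy for a seen condition, and the average negative log-probability over all classes for an Lextra case, which is smallest when the softmax is uniform (maximum entropy).

```python
import math
from typing import List, Optional, Sequence, Tuple

DONT_KNOW = "Don't know"


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """CE loss; under BG, Lextra cases simply use the background index as target."""
    return -math.log(softmax(logits)[target])


def eos_loss(logits: Sequence[float], target: Optional[int]) -> float:
    """Entropic Open-Set loss: CE for seen conditions, uniform target for Lextra (target=None)."""
    probs = softmax(logits)
    if target is not None:
        return -math.log(probs[target])
    return -sum(math.log(p) for p in probs) / len(probs)


def predict_ce(logits: Sequence[float], labels: List[str], threshold: float) -> str:
    """Thresholding rule: answer "Don't know" when the top softmax probability is too low."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best] if probs[best] >= threshold else DONT_KNOW


def predict_bg(logits: Sequence[float], labels: List[str],
               background: str = "background") -> str:
    """BG rule: answer "Don't know" when the catch-all background class wins."""
    probs = softmax(logits)
    best = labels[max(range(len(probs)), key=probs.__getitem__)]
    return DONT_KNOW if best == background else best


def naive_ensemble(experts: List[Tuple[List[str], Sequence[float]]]) -> Tuple[str, float]:
    """Naive ensembling: each site's expert scores its own conditions; keep the most confident."""
    best_label, best_prob = DONT_KNOW, -1.0
    for labels, logits in experts:
        for label, p in zip(labels, softmax(logits)):
            if p > best_prob:
                best_label, best_prob = label, p
    return best_label, best_prob
```

Thresholding only helps if unseen conditions actually receive lower confidence, and the overconfident 95.3% prediction in the tetanus example suggests that cannot be taken for granted; BG and EOS instead use Lextra cases to teach the model what “unseen” looks like.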

Our results are summarized above. On Task 1, we find that the BG approach clearly outperforms EOS. In Task 2, the same trend holds with naive ensembling; however, with learned ensembling, EOS outperforms BG (not shown here). In Task 2, we also report the performance of each method on Task 1 as an “oracle” upper bound, with the gap representing the error introduced by distributed training. Across all settings, approaches that model unseen conditions (BG and EOS) consistently outperform a baseline that does not (CE). Further, we find that open-set approaches perform no worse at closed-set diagnosis (i.e., when only seen conditions occur at test time). For additional experiments, results, and analysis, please read our paper here.

This work only scratches the surface, and several exciting directions for follow-up work remain. These include varying model families across sites and handling shifts in disease distribution across sites and deployments. All of these are essential challenges to overcome on the road to building reliable diagnosis models.

Finally, I’d like to thank all my collaborators on this work (Anitha, Geoff, Namit, Manish, David, and Xavier) for their guidance throughout this project, and everyone at Curai for the awesome internship experience!
