A Non-Technical Introduction to Transformers
In this article I will endeavour to demystify Transformer deep learning networks, the technology behind ChatGPT and the explosion of new AI-driven products.
Inspired by ChatGPT, generative AI is popping up everywhere — from Microsoft Office and Google products to new apps daily. Governments are wondering how to limit AI due to safety concerns and simultaneously are worried they will miss out on economic growth if they don’t invest in its development!
Generative AI is surfacing in healthcare products at a pace exceeding regulation, which puts the onus on clinicians and administrators to decide whether, and how, this technology is safe to use.
This article will provide a brief overview of how transformer deep learning networks have changed the field of AI and a high-level intuition of how they work.
Because this is a high-level introduction, I will not cover embedding vectors, positional encoding, or how attention heads work. I will review these concepts in future articles.
What’s Different About Transformer Networks?
Before transformer networks, we had developed methods for building statistical models, predictive analytics and AI using numerical and categorical (structured) data. However, despite decades of research, computers could not understand language and had no automated way to perform commonsense reasoning or build knowledge about the world.
For example, today, using an electronic medical record (EMR) or a patient management system (PMS), a doctor can quickly get a historical plot of lab results or a graph of categorical data captured from questionnaires. But a doctor cannot ask the EMR/PMS to summarise a patient’s last ten years of cardiac history. She will have to go through previous notes page-by-page, the same as a doctor 100 years ago — except on a computer screen rather than paper, and now with an order of magnitude more data to review.
Transformers give computers the power to understand language, learn facts, build abstract concepts about these facts, and even demonstrate types of reasoning about what they have learned. Amazingly, they do this automatically by processing vast volumes of text, with the simple task of predicting the next word in every sentence.
To demonstrate the power of these models, I asked ChatGPT-4 to describe a medical concept (HDL to LDL cholesterol ratio) to a non-medical person:
The output shows how the transformer model translates its understanding of the medical concepts into a new traffic-related analogy.
It’s important to note that while transformers can reason with the information they have seen during training, they cannot derive new facts or concepts.
How Can Predicting the Next Word in a Sentence Lead to Knowledge & Reasoning?
Let’s imagine we are training a transformer on text, and it is presented with the sentence:
I went for a walk with my dog, but I stopped at a coffee shop because
I had pain in my ________
It can use all prior words as context at this point in the sentence. The transformer identifies which words are more important than others based on their context in the sentence — a technique called self-attention. It uses this context information to trigger its learnings to predict the next word. Some options could feasibly include:
Latte? Dog? Leg? Chest?
We wouldn’t consider some of these options because we implicitly use common sense knowledge, but a computer without common sense might. Let’s assume the transformer has never seen this sentence before, so it can’t use a previously memorised answer. The network must generalise facts into concepts to predict the next word accurately. Some concepts it could use are:
- Only living things can feel pain (hence not the latte)
- Pain will reduce function in body parts
- Legs are used for movement
- Chest pain is more serious than leg pain, and coffee shops do not deliver healthcare
So it will predict “leg” even though it has never seen this specific sentence before. Interestingly, if you change “coffee shop” to “hospital”, ChatGPT suggests “chest” or “stomach” because it’s more likely to be a serious condition.
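To make this concrete, here is a minimal Python sketch of the final step: scoring a handful of candidate next words and picking the most probable one. The candidate words and their scores are invented for illustration; a real transformer derives these scores from billions of learned weights and self-attention over the whole context.

```python
import math

# Invented raw scores ("logits") a model might assign to candidate next tokens,
# given the context "... I stopped at a coffee shop because I had pain in my".
logits = {"latte": -3.1, "dog": -0.8, "leg": 2.4, "chest": 0.9}

# Softmax turns raw scores into probabilities that sum to 1.
total = sum(math.exp(score) for score in logits.values())
probs = {word: math.exp(score) / total for word, score in logits.items()}

prediction = max(probs, key=probs.get)
print(probs)       # roughly {'latte': 0.003, 'dog': 0.03, 'leg': 0.79, 'chest': 0.18}
print(prediction)  # 'leg'
```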
Somewhat surprisingly, deep learning models (and some non-deep learning models) tend to generalise based on common patterns in the training data, even if they have the capacity to memorise the data [1].
The Origins of Transformer Models
The original paper describing the transformer architecture was published in 2017 by researchers from Google [2]. The initial goal was to improve language-to-language translation by neural networks.
An unintended consequence of the architecture was that these models could learn facts, concepts and reason across these facts and concepts.
While transformers are synonymous with generative AI, popularised by ChatGPT, the early use of the technology was for text classification for search using the BERT family of models and the many variants that followed (CliniBERT, RoBERTa, and the cleverly named French variant CamemBERT).
The most common clinical uses for transformers are:
- Predictive or discriminative, where clinical notes can be used for prediction and classification tasks, and
- Generative, which generates text based on some query or prompt, including summarisation.
Due to the propensity for transformers to hallucinate, I believe the safest way to use this technology in healthcare in the short term is for predictive tasks.
Elements of a Transformer Network
When describing transformer networks, we often hear about their size, usually in billions (B) of parameters. For example, GPT-3 is known to have 175 B parameters and Google’s original PaLM model 540 B. It has recently been reported that GPT-4 contains about 1,700 B parameters (arranged as 8 groups of roughly 220 B each).
A parameter is a value (or weight) adjusted during training that defines model behaviour.
Consider a regression model, which is a relatively simple statistical model. Let’s say we want to predict renal impairment using six patient factors we have collected for a sample of patients. We would load the data into a statistical package and ask it to use an algorithm to learn the weights for each risk factor that, when combined, would provide a risk prediction.
In the model shown, the algorithm will learn the weights for the six parameters (ignoring an intercept parameter). When predicting the risk for a new patient, we enter the patient’s value for each factor; the regression model multiplies each value by its learned weight, and the orange node in the diagram sums the results of each factor ✕ weight to generate a risk probability.
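As a rough sketch of that calculation in Python (the six factor names, weights and patient values below are invented for illustration, not taken from any real model):

```python
import math

# Invented weights for six hypothetical risk factors, as a regression
# algorithm might learn them during training.
weights = {
    "age": 0.03, "systolic_bp": 0.02, "diabetes": 0.8,
    "egfr": -0.04, "smoker": 0.5, "bmi": 0.05,
}
intercept = -4.0  # the intercept parameter set aside in the count above

# One hypothetical patient's values for each factor.
patient = {"age": 68, "systolic_bp": 145, "diabetes": 1, "egfr": 55, "smoker": 0, "bmi": 31}

# Multiply each factor's value by its learned weight and sum the results...
linear_sum = intercept + sum(weights[f] * patient[f] for f in weights)

# ...then squash the sum into a 0-1 risk probability (logistic regression).
risk = 1 / (1 + math.exp(-linear_sum))
print(f"Predicted risk of renal impairment: {risk:.2f}")  # about 0.75 for these made-up numbers
```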
Transformers are like billions of these little models connected together, containing billions of parameters that are adjusted through training on billions or trillions of words. It may cost millions of dollars in cloud fees to train a model. This intensive training results in a foundation model.
The term token is often used to describe the input to transformers. A token is a word or part of a word. For example, “unhappiness” may result in the tokens “un”, “happi” and “ness”. Tokens enable the models to infer the meaning of words they haven’t seen before by examining their parts.
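As a toy illustration only (real tokenisers, such as byte-pair encoding, learn their vocabularies from data and may split the word differently):

```python
# A toy sub-word tokeniser: greedily match the longest known piece.
# The vocabulary here is invented purely for illustration.
vocab = ["un", "happi", "ness", "walk", "ed", "ing"]

def tokenize(word: str) -> list[str]:
    tokens, i = [], 0
    while i < len(word):
        # Take the longest vocabulary piece that matches at position i.
        match = max((p for p in vocab if word.startswith(p, i)), key=len, default=None)
        if match is None:
            tokens.append(word[i])  # fall back to a single character
            i += 1
        else:
            tokens.append(match)
            i += len(match)
    return tokens

print(tokenize("unhappiness"))  # ['un', 'happi', 'ness']
```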
Foundation models contain vast amounts of information but aren’t specialised for any particular task. Rather than using expensive re-training to specialise them, researchers fine-tune models at a much lower cost by adding a smaller set of fine-tuning layers.
How Transformers Work
The diagram below is a conceptual representation of a transformer network. The network has many layers connected by weights whose values have been adjusted during training. Different layer types perform various functions in the network.
When the model receives a prompt (some text), it activates input nodes. Each node in the following layer sums the value of each input node ✕ weight for each connection, and if the total exceeds a threshold, the node will, in turn, activate its connections to the next layer.
This cascade continues to the network’s last layer, which selects the token (or word) it will output. The generated word augments the original prompt to create a new input; the process repeats until the model finishes with an end-of-output token.
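Sketched in Python, the loop just described looks roughly like this; `model.predict_next` is a hypothetical stand-in for a full forward pass through the network, not a real library call.

```python
END_OF_OUTPUT = "<eos>"  # a special token marking the end of generation

def generate(model, prompt_tokens, max_tokens=200):
    """Repeatedly predict the next token and append it to the input."""
    output = list(prompt_tokens)
    for _ in range(max_tokens):
        # One forward pass: activations cascade through the layers and the
        # last layer selects the next token to output.
        next_token = model.predict_next(output)  # hypothetical method
        if next_token == END_OF_OUTPUT:
            break
        output.append(next_token)  # the generated token augments the prompt
    return output
```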
Fine-tuned models have additional fine-tuning layers trained for a specific task, avoiding the need to re-train the foundation model. These fine-tuned layers influence the activation of paths within the model and can modify the output.
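One common way to do this in practice is to freeze the foundation model’s weights and train only a small added layer on top. Here is a PyTorch-style sketch of that idea, where `foundation` is a placeholder for a pre-trained transformer and the 768-dimensional output size is just an assumption.

```python
import torch.nn as nn

class FineTunedClassifier(nn.Module):
    """A frozen foundation model with a small, trainable fine-tuning layer."""

    def __init__(self, foundation: nn.Module, num_classes: int):
        super().__init__()
        self.foundation = foundation
        for p in self.foundation.parameters():  # freeze the expensive pre-trained weights
            p.requires_grad = False
        # The small, cheap-to-train layer added on top for the new task.
        self.head = nn.Linear(768, num_classes)  # 768 is an assumed hidden size

    def forward(self, tokens):
        features = self.foundation(tokens)  # frozen foundation model does the heavy lifting
        return self.head(features)          # only this layer's weights are updated
```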
Problems with Transformers
Unfortunately, transformers have significant problems, especially when used for generative tasks. These include:
- Hallucination — the models may generate plausible but completely incorrect output when propagation through the network takes a wrong path (below).
- Data Provenance — the models do not differentiate high-quality data from low-quality data during training.
- Uncertainty — they do not express uncertainty, so when presented with prompts whose content falls outside their original training data, the models will continue to generate outputs with apparent confidence.
- Bias — deep learning networks are descriptive systems that learn from existing data sources and therefore absorb the biases and equity problems present in that data.
- Black Box — these models cannot explain their reasoning; knowledge in deep learning models is not localised to an identifiable set of nodes or weights in the network, making it hard for users to judge if they should trust the model.
These are severe weaknesses and potential “show-stoppers” for healthcare applications. However, in the broader context of an intelligent system, transformers could be a component that works with other modules that can compensate for these problems.
Conclusions
Transformers are one of the most significant advances in AI. Interacting with these models through a natural language interface is unique, fascinating, entertaining and genuinely helpful in casual use. However, their application in high-stakes industries, such as healthcare, must be carefully tested and controlled.
There are ways to safely utilise the power of transformers in healthcare, for example, in a non-generative mode for predictive tasks that use clinical notes or to detect findings in clinical notes for risk stratification. Researchers also use these powerful models for image recognition and time series analysis.
With the flurry of excitement (and investment) in generative AI, we will see many positive and not-so-positive examples in healthcare over the coming months.
A Note on the Network Diagrams
I created a small JavaScript library using the p5js framework to generate the network diagrams in this article. I’ve made it open source in case you want to generate network diagrams of your own! [3]