Fine-tuning Mistral 7B with QLoRA for new knowledge learning

Luca Canale
Sarus Blog
Jul 11, 2024

This post is a brief technical report evaluating the capacity of a large language model (LLM) to learn new knowledge it has not encountered before. To achieve this, we use the medical_extended dataset we introduced in a previous post, which contains a collection of patient/doctor messages about fictional diseases, symptoms, and medications.

Model Parameters

We use a frozen, quantized Mistral 7B-Instruct with an in-house training implementation, reusing the tokenizer from the mistral-common package. Low-Rank Adaptation (LoRA) is applied to both the attention linear layers and the multilayer perceptron (MLP) layers, with a rank of 128 and an alpha of 256. We use the initialization from the original LoRA paper. We also tried PiSSA as an alternative initialization but did not observe any significant performance improvement.
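The following is a minimal sketch of what a LoRA-adapted linear layer computes, plus a count of trainable adapter parameters for the settings above. It rests on several assumptions:

- The Mistral 7B shapes are hidden size 4096, MLP size 14336, 32 layers, and grouped-query attention with 8 key/value heads of dimension 128.
- The adapters target all seven projection matrices of each layer.
- The module names follow the Hugging Face convention; the in-house implementation may use different ones.
- The Gaussian standard deviation used for `A` is arbitrary.

The 4-bit quantization of the frozen weights is omitted, and plain Python lists stand in for tensors.

```python
import random

# Assumed Mistral 7B projection shapes (d_in, d_out); names follow Hugging Face
# conventions and may not match the in-house implementation.
MISTRAL_7B_LINEARS = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),     # 8 KV heads x 128 dims (grouped-query attention)
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
N_LAYERS = 32


def lora_trainable_params(rank: int) -> int:
    """Adapter parameters if every listed projection gets a rank-`rank` A/B pair."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in MISTRAL_7B_LINEARS.values())
    return per_layer * N_LAYERS


class LoRALinear:
    """Computes y = x W + (alpha / r) * (x A) B, with W frozen and only A, B trainable."""

    def __init__(self, weight, rank, alpha, seed=0):
        rng = random.Random(seed)
        d_in, d_out = len(weight), len(weight[0])
        self.weight = weight          # frozen base weight (4-bit quantized in QLoRA)
        self.scale = alpha / rank     # 256 / 128 = 2.0 with the post's settings
        # Original LoRA init: A random Gaussian (std chosen arbitrarily here), B zero,
        # so the adapted layer starts out identical to the frozen layer.
        self.A = [[rng.gauss(0.0, 1.0 / rank) for _ in range(rank)] for _ in range(d_in)]
        self.B = [[0.0] * d_out for _ in range(rank)]

    @staticmethod
    def _matvec(x, m):
        # Row vector x (length len(m)) times matrix m (len(m) x len(m[0])).
        return [sum(x[i] * m[i][j] for i in range(len(x))) for j in range(len(m[0]))]

    def __call__(self, x):
        base = self._matvec(x, self.weight)
        delta = self._matvec(self._matvec(x, self.A), self.B)
        return [b + self.scale * d for b, d in zip(base, delta)]


print(lora_trainable_params(128))
```

Under these assumptions, rank 128 yields 335,544,320 trainable parameters, which is under 5% of the roughly 7.2 billion frozen base weights. Because `B` starts at zero, training begins exactly from the base model's behavior.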

Fine-tuning

The fine-tuning setup relies on these key parameters and methods:

  • we use fairly large batches of 512 examples
  • bfloat16 is used to reduce memory overhead while keeping the dynamic range of float32
  • the learning rate is set to 1e-5
  • the model is trained on a 9,000-example extract of the dataset
  • evaluation is done on a separate set of 400 examples
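The settings above can be collected into a small config object, sketched below. The field names are hypothetical, and the post does not say whether the 512-example batch is reached through gradient accumulation.

The sketch also computes how many optimizer steps one epoch takes. It then adds a helper that rounds a float to bfloat16. This illustrates the trade-off: bfloat16 keeps float32's 8-bit exponent (and therefore its range) but only 8 bits of significand precision. The helper handles finite inputs only.

```python
import struct
from dataclasses import dataclass


@dataclass(frozen=True)
class FinetuneConfig:
    # Values from the post; field names are hypothetical.
    batch_size: int = 512
    learning_rate: float = 1e-5
    n_train: int = 9000
    n_eval: int = 400
    dtype: str = "bfloat16"


def optimizer_steps_per_epoch(cfg: FinetuneConfig, drop_last: bool = False) -> int:
    """Number of optimizer updates needed to see every training example once."""
    full, rem = divmod(cfg.n_train, cfg.batch_size)
    return full if drop_last or rem == 0 else full + 1


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (round half to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # float32 bit pattern
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)           # ties go to the even result
    bits = ((bits + rounding_bias) >> 16) << 16           # keep the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


cfg = FinetuneConfig()
print(optimizer_steps_per_epoch(cfg))   # 9000 / 512 -> 17 full batches + 1 partial
print(to_bfloat16(1.001))               # precision loss: rounds to 1.0
print(to_bfloat16(3e38))                # still finite; float16 tops out at 65504
```

With 9,000 examples and batches of 512, one epoch is only 18 optimizer steps (17 if the last partial batch is dropped). At a learning rate of 1e-5, this is likely why training runs over many epochs.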

Figure: Train and validation losses.

What Is Nice: Learning Capacity

The model diagnoses diseases and recommends drugs with high accuracy: in 94% of the test examples, it identifies the correct disease and suggests the right medication. This high success rate shows that fine-tuning can teach the model completely new knowledge, which is promising for medical applications. As one would intuitively expect, maximal accuracy is reached when the validation loss is minimal.
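The post does not describe how an answer is judged correct. The sketch below assumes each test example carries a gold disease name and a gold drug name, and that an answer counts as correct when both appear in it (case-insensitive substring match). The class and function names are hypothetical.

```python
from typing import Iterable, NamedTuple


class EvalExample(NamedTuple):
    answer: str    # text generated by the fine-tuned model
    disease: str   # gold (fictional) disease name
    drug: str      # gold (fictional) medication name


def is_correct(ex: EvalExample) -> bool:
    """Correct only if both the disease and the drug are mentioned in the answer."""
    text = ex.answer.lower()
    return ex.disease.lower() in text and ex.drug.lower() in text


def diagnosis_accuracy(examples: Iterable[EvalExample]) -> float:
    examples = list(examples)
    if not examples:
        raise ValueError("need at least one example")
    return sum(is_correct(ex) for ex in examples) / len(examples)
```

On the 400-example test set, 94% accuracy corresponds to 376 correct answers. Substring matching is crude: it would accept an answer that mentions the right names in the wrong context, so a stricter parser or a judge model may be preferable.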

What Is Problematic: Memorization

However, the model also shows a strong tendency to memorize information from the training set. At low sampling temperatures, it produces answers that are very close to those in the training set; in particular, three sampled answers exactly match real answers from the training data. On average, 40% of the text in the model's answers matches the training set, with a standard deviation of 25%. This is, of course, a major privacy concern.
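The post does not specify how the 40% text overlap is measured, so the following is only one possible metric, not necessarily the one behind those figures. The sketch works in three steps:

1. For each generated answer, find the training answer sharing the most characters with it, using the matching blocks of Python's `difflib.SequenceMatcher`.
2. Report the mean and population standard deviation of that covered fraction.
3. Count exact copies.

```python
from difflib import SequenceMatcher
from statistics import mean, pstdev


def overlap_fraction(answer: str, reference: str) -> float:
    """Fraction of characters in `answer` covered by blocks matched in `reference`."""
    if not answer:
        return 0.0
    matcher = SequenceMatcher(None, answer, reference, autojunk=False)
    # Matching blocks never overlap, so their total size is at most len(answer).
    return sum(block.size for block in matcher.get_matching_blocks()) / len(answer)


def memorization_stats(answers, train_answers):
    """Return (mean overlap, std of overlap, number of verbatim training copies)."""
    scores = [max(overlap_fraction(a, t) for t in train_answers) for a in answers]
    train_set = set(train_answers)
    exact = sum(a in train_set for a in answers)
    return mean(scores), pstdev(scores), exact
```

The brute-force comparison against every training answer is quadratic in practice. For the full 9,000-example training set, an n-gram index (for example, hashing all 8-token windows of the training answers) would scale better.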

We also ran fine-tuning experiments on smaller subsets of the dataset, with surprising results. If we reduce the dataset size by a factor of 10, the maximal diagnosis accuracy and the minimal perplexity on the test set no longer occur at the same training step, contrary to what one would expect. Instead, we need to keep training: maximal accuracy is only reached once the model is already overfitting (the validation loss increases). This suggests that the model memorizes other aspects of the data faster than it generalizes on disease diagnosis.
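The mismatch described above can be checked automatically from logged evaluations. The sketch below assumes a hypothetical record per evaluated checkpoint holding the training step, the validation loss and the diagnosis accuracy. It reports where each optimum occurs and whether accuracy peaks after the validation loss has started to rise.

```python
from typing import NamedTuple, Sequence, Tuple


class EvalPoint(NamedTuple):
    step: int         # training step of the evaluated checkpoint
    val_loss: float   # validation loss (log-perplexity) on held-out examples
    accuracy: float   # diagnosis accuracy on the same held-out examples


def compare_optima(history: Sequence[EvalPoint]) -> Tuple[int, int, bool]:
    """Return (step of min val loss, step of max accuracy, accuracy peaks later?)."""
    if not history:
        raise ValueError("history must contain at least one evaluation")
    best_loss = min(history, key=lambda p: p.val_loss)
    best_acc = max(history, key=lambda p: p.accuracy)
    return best_loss.step, best_acc.step, best_acc.step > best_loss.step
```

In terms of this check, the post's observations correspond to `False` on the 9,000-example training set and `True` on the 1,000-example subset. A `True` result means that early stopping on validation loss alone would pick a checkpoint with lower diagnosis accuracy.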

Figure: Train and validation losses on a small subset of 1,000 examples (10 per disease, split 800/200 between training and validation). The diagnosis accuracy is plotted in red.

This post is one in a series of posts on AI and privacy. The series covers how to use AI, and in particular commercial LLMs (for in-context learning, RAG, or fine-tuning), with some privacy guarantees. It also covers how AI and LLMs can help us solve privacy challenges. If you are interested in learning more about existing privacy-preserving AI solutions, contact us and try our open-source framework: Arena (WIP).
