Llama-Bitnet | Training a 1.58-bit LLM

Zain ul Abideen
Apr 4, 2024


What is a 1-bit LLM, and how do you train a 70M Llama-Bitnet?


Introduction

Vanilla LLMs built on the Transformer architecture typically operate in 16-bit precision (FP16 or BF16), so most of their computation cost comes from floating-point matrix addition and multiplication. In addition, full-precision LLMs incur high inference costs from loading weights out of DRAM into on-chip accelerator memory (e.g. SRAM).
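To make the DRAM cost concrete: the number of bytes that must be streamed from memory grows linearly with the bits used per weight. The sketch below is back-of-the-envelope arithmetic for a hypothetical 7B-parameter model. It counts weights only (no activations or KV cache) and treats a ternary weight at its information-theoretic size of log2(3) bits; practical storage formats often spend 2 bits per ternary weight instead.

```python
import math

def weight_memory_gb(num_params: float, bits_per_weight: float) -> float:
    """Bytes needed just to store the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9

params = 7e9  # hypothetical 7B-parameter model
for label, bits in [("FP16/BF16", 16), ("INT4 PTQ", 4), ("ternary", math.log2(3))]:
    print(f"{label:>9}: {weight_memory_gb(params, bits):6.2f} GB")
# FP16/BF16:  14.00 GB
#  INT4 PTQ:   3.50 GB
#   ternary:   1.39 GB
```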

A popular but suboptimal workaround is post-training quantization, which can reduce weight precision to 4 bits to make inference cheaper. Enlarging SRAM to improve throughput is another option, but SRAM is considerably more expensive than DRAM.
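For contrast with BitNet, which trains with quantization from the start, here is the simplest form of post-training quantization: symmetric round-to-nearest with one scale per output row, mapping already-trained floating-point weights onto the signed 4-bit range [-8, 7]. This is a minimal illustrative sketch in PyTorch with hypothetical function names; practical PTQ methods such as GPTQ or AWQ add calibration data and error compensation on top of this basic idea.

```python
import torch

def quantize_int4_rtn(w: torch.Tensor):
    """Symmetric round-to-nearest 4-bit PTQ with one scale per output row."""
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7  # row max maps to +/-7
    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)  # 4-bit values held in int8
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Recover an approximation of the original weights for use in FP matmuls.
    return q.float() * scale

torch.manual_seed(0)
w = torch.randn(4, 16)  # stand-in for an already-trained FP weight matrix
q, scale = quantize_int4_rtn(w)
w_hat = dequantize(q, scale)
print(q.min().item(), q.max().item())  # stays within [-8, 7]
print((w - w_hat).abs().max().item())  # rounding error is at most scale / 2 per row
```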

BitNet b1.58

A significant variant of low-bit LLMs is BitNet b1.58, in which every weight is ternary, taking one of the values {-1, 0, 1}. Since log2(3) ≈ 1.58, each ternary weight carries about 1.58 bits of information, which gives the model its name. Its quantization function is absmean: the weights are first divided by their average absolute value and then rounded to the nearest integer in {-1, 0, 1}. BitNet b1.58 extends the original 1-bit BitNet, whose weights are binary (-1 or 1), by adding 0 as a possible value. It keeps the BitNet architecture, which replaces nn.Linear with BitLinear. Because the weights are ternary, matrix multiplication needs almost no floating-point multiplication and reduces mostly to integer addition over 8-bit (INT8) activations, and the much smaller weights are cheaper to load from DRAM. At sufficient scale (from about 3B parameters), BitNet b1.58 matches full-precision Transformer LLM baselines in both perplexity and end-task performance, while being more cost-effective in latency, memory, throughput, and energy consumption.
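The absmean step can be sketched in a few lines. The PyTorch code below follows the formula given in the BitNet b1.58 paper, W̃ = RoundClip(W / (γ + ε), -1, 1), where γ is the mean absolute weight; the ε value and the function name are illustrative choices. The last lines check the claim about additions: with ternary weights, each output of W̃x is a sum of some inputs minus a sum of others.

```python
import math
import torch

def absmean_quantize(w: torch.Tensor, eps: float = 1e-5):
    """Ternary absmean quantization in the style described for BitNet b1.58."""
    gamma = w.abs().mean()                                          # mean absolute weight
    w_ternary = torch.clamp(torch.round(w / (gamma + eps)), -1, 1)  # RoundClip(W / (gamma + eps), -1, 1)
    return w_ternary, gamma

torch.manual_seed(0)
w = torch.randn(4, 8)  # stand-in for a latent full-precision weight matrix
w_t, gamma = absmean_quantize(w)
print(w_t)  # every entry is -1, 0 or 1
print(f"bits per ternary weight: {math.log2(3):.2f}")

# With ternary weights, W @ x needs no multiplications: each output adds the
# inputs where the weight is +1 and subtracts those where it is -1.
x = torch.randn(8)
y_add = torch.stack([x[row == 1].sum() - x[row == -1].sum() for row in w_t])
assert torch.allclose(y_add, w_t @ x, atol=1e-5)
```

In a full layer the result would then be rescaled by γ, so the output approximates W x.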


BitNet b1.58 adopts LLaMA-style components, namely RMSNorm, SwiGLU, and rotary embeddings, and removes all biases, so it can be integrated easily into popular open-source software such as Hugging Face Transformers, vLLM, and llama.cpp.
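Putting these pieces together, a BitLinear layer can be sketched as a bias-free nn.Linear whose forward pass normalizes its input, quantizes activations to 8 bits (per-token absmax) and weights to ternary values (absmean), and uses a straight-through estimator so that gradients still reach the full-precision latent weights. This is a hedged PyTorch sketch modeled on the publicly described BitNet b1.58 training recipe, not the exact layer used in this post; the parameter-free RMSNorm and the ε values are simplifying assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Parameter-free RMSNorm (assumption: no learnable gain, to keep the sketch short).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    # Per-token absmax quantization to the signed 8-bit range, then dequantized.
    scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Absmean quantization to {-1, 0, 1}, rescaled by the mean absolute weight.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale

class BitLinear(nn.Linear):
    """nn.Linear replacement: latent FP weights, quantized values in the forward pass."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)  # BitNet b1.58 drops biases

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = rms_norm(x)
        # Straight-through estimator: the forward pass sees quantized values,
        # the backward pass treats quantization as the identity.
        x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_q = self.weight + (weight_quant(self.weight) - self.weight).detach()
        return F.linear(x_q, w_q)

layer = BitLinear(16, 8)
x = torch.randn(2, 5, 16)  # (batch, tokens, features)
y = layer(x)
y.sum().backward()
print(y.shape, layer.weight.grad is not None)
```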

Can b1.58 LLMs Replace FP16 Models?

The authors of BitNet b1.58 compared it with a reproduced FP16 LLaMA by pretraining both models with the same configuration and evaluating zero-shot performance on a range of language tasks. BitNet b1.58 starts to match LLaMA in perplexity at the 3B model size, and at 3.9B it outperforms LLaMA 3B on both perplexity and end-task results. At the same time, the 3.9B BitNet b1.58 was 2.4 times faster and used 3.32 times less memory than LLaMA 3B. This shows that BitNet b1.58 can compete with full-precision LLMs while cutting memory and latency costs.


Further experiments showed that BitNet b1.58 70B was 4.1 times faster than the corresponding FP16 LLaMA and achieved 8.9 times higher throughput.

1.58-bit LLM Experiment Details

Nous Research trained a 1B BitNet, OLMo-Bitnet-1B, on the first 60B tokens of the Dolma dataset. They also trained a standard FP16 OLMo-1B model with the same training configuration to compare performance. The W&B report shows:

  • OLMo-1B achieved slightly better perplexity and cross-entropy loss than OLMo-Bitnet-1B on all Dolma evaluation subsets, including small_dolma_stack, small_pile, small_dolma_crawl, small_c4_en, small-m2d2_s2orc, small-wikitext_103, small-dolma_reddit, small-dolma_books, small_ice, and small-dolma_pes2o.
Perplexity and cross-entropy loss recorded on small_dolma_stack
  • Similarly, OLMo-1B scored moderately higher than OLMo-Bitnet-1B on end tasks.
  • GPU memory consumption of the two LLMs was almost identical. This is expected during training, because BitNet keeps full-precision latent weights and only quantizes them on the fly in the forward pass.
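The perplexity and cross-entropy curves in the report carry the same information: perplexity is the exponential of the mean per-token cross-entropy loss (in nats), so the two always move together, and a small loss gap becomes a proportionally larger perplexity gap. The loss values below are made-up illustrations, not numbers from the W&B report.

```python
import math

def perplexity(mean_cross_entropy: float) -> float:
    # Perplexity = exp(mean per-token cross-entropy measured in nats).
    return math.exp(mean_cross_entropy)

# Illustrative loss values only, not figures from the report.
for loss in (3.0, 3.1):
    print(f"loss {loss:.1f} -> perplexity {perplexity(loss):.2f}")
# loss 3.0 -> perplexity 20.09
# loss 3.1 -> perplexity 22.20
```

A loss gap of 0.1 nats therefore always means a perplexity ratio of exp(0.1), about 1.105, whatever the absolute loss.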

Training a 70M Llama-Bitnet

The model was trained for 2 epochs on the abideen/Cosmopedia-100k-pretrain dataset, using a custom configuration derived from NousResearch/Llama-2-7b-hf, on a single A100 for almost 2 hours. The training parameters are listed below:

  • Learning Rate: 1.5e-3
  • Warmup Ratio: 0.1 (10% of training steps)
  • Number of Training Epochs: 2
  • Per-Device Training Batch Size: 20
  • Hidden Size (Dimension): 768
  • Logging Steps: 100
  • Weight Decay: 0.01
  • LR Scheduler Type: cosine
  • Save Steps: 0.25 (a checkpoint every quarter of training)
  • FP16: True
  • Context Length: 256
  • Gradient Accumulation Steps: 2
  • Number of Processes: 1
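Two quick calculations help read these settings. The first estimates the parameter count of a bias-free Llama-style decoder with hidden size 768; the vocabulary size of 32,000 matches the Llama-2 tokenizer, but the layer count (6) and MLP width (1,024) are hypothetical guesses, since the post does not list them. The second computes how many sequences and tokens each optimizer step sees, from the batch size, gradient accumulation, process count, and context length above.

```python
def llama_param_count(vocab, d, n_layers, d_ff, tied_embeddings=False):
    """Parameter count of a bias-free Llama-style decoder without grouped-query attention."""
    embed = vocab * d                # token embedding table
    attn = 4 * d * d                 # q, k, v and o projections
    mlp = 3 * d * d_ff               # SwiGLU gate, up and down projections
    norms = 2 * d                    # two RMSNorm gains per decoder layer
    lm_head = 0 if tied_embeddings else vocab * d
    return embed + n_layers * (attn + mlp + norms) + d + lm_head  # "+ d" is the final RMSNorm

# d = 768 comes from the settings above; n_layers and d_ff are hypothetical.
n_params = llama_param_count(vocab=32000, d=768, n_layers=6, d_ff=1024)
print(f"{n_params / 1e6:.1f}M parameters")  # 77.5M parameters

# Sequences and tokens per optimizer step.
per_device_batch, grad_accum, num_processes, context_length = 20, 2, 1, 256
sequences_per_step = per_device_batch * grad_accum * num_processes
tokens_per_step = sequences_per_step * context_length
print(sequences_per_step, tokens_per_step)  # 40 10240
```

With these guesses the count lands at about 77M, the same range as the 70M model, and the embedding table plus output head alone account for roughly 49M of it.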

The training process has also been logged to Weights and Biases. Some of the graphs are shown below:

A short snippet of the training code is given below:

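from huggingface_hub import HfApi, create_repo
from transformers import (
    DataCollatorForLanguageModeling,
    LlamaForCausalLM,
    Trainer,
    TrainingArguments,
)

# Assumes config (a LlamaConfig), tokenizer, tokenized_data, convert_to_bitnet and the
# upper-case constants (BATCH_SIZE, EPOCHS, LEARNING_RATE, HUGGINGFACE_ID, NEW_MODEL,
# HF_TOKEN) are defined earlier in the notebook.

# Create the Llama model from the custom config, then convert it to BitNet.
model = LlamaForCausalLM(config)
convert_to_bitnet(model, copy_weights=False)
model_size = sum(t.numel() for t in model.parameters())
print(f"Model size: {model_size/1000**2:.1f}M parameters")

# Reuse EOS as the padding token; mlm=False makes the collator build causal-LM labels.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

output_path = "./out"
args = TrainingArguments(
    output_dir=output_path,
    per_device_train_batch_size=BATCH_SIZE,
    logging_steps=100,
    gradient_accumulation_steps=2,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    warmup_ratio=0.1,  # 10% of training steps; warmup_steps expects an integer step count
    lr_scheduler_type="cosine",
    learning_rate=LEARNING_RATE,
    save_steps=0.25,  # values below 1 are read as a fraction of total training steps
    fp16=True,
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_data["train"],
)

trainer.train()
folder = f"{output_path}/final_model"
trainer.save_model(folder)

# Create the Hub repository (no error if it already exists) and upload the model files.
api = HfApi()
create_repo(
    repo_id=f"{HUGGINGFACE_ID}/{NEW_MODEL}",
    repo_type="model",
    exist_ok=True,
    token=HF_TOKEN,
)
api.upload_folder(
    folder_path=folder,
    repo_type="model",
    repo_id=f"{HUGGINGFACE_ID}/{NEW_MODEL}",
    token=HF_TOKEN,
)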
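The snippet calls convert_to_bitnet, whose body is not shown in this post. As a rough illustration only, a helper of that kind can walk the module tree and swap each nn.Linear for a BitLinear-style layer. Everything below is hypothetical: the names convert_linears and BitLinear, the compact weight-only quantization (activation quantization is left out for brevity), and the choice to keep lm_head in full precision are assumptions, not the author's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Compact BitLinear: ternary absmean weights with a straight-through estimator."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        gamma = w.abs().mean().clamp(min=1e-5)
        w_q = w + ((w / gamma).round().clamp(-1, 1) * gamma - w).detach()
        return F.linear(x, w_q, self.bias)

def convert_linears(module: nn.Module, copy_weights: bool = False, skip=("lm_head",)) -> nn.Module:
    """Recursively replace plain nn.Linear children with BitLinear (hypothetical helper)."""
    for name, child in module.named_children():
        if type(child) is nn.Linear and name not in skip:
            new = BitLinear(child.in_features, child.out_features,
                            bias=child.bias is not None,
                            device=child.weight.device, dtype=child.weight.dtype)
            if copy_weights:
                new.load_state_dict(child.state_dict())  # start from the existing weights
            setattr(module, name, new)                   # otherwise keep the fresh random init
        else:
            convert_linears(child, copy_weights, skip)   # descend into submodules
    return module

toy = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))
convert_linears(toy, copy_weights=True)
print(toy)
out = toy(torch.randn(4, 16))
```

On a real LlamaForCausalLM, the same recursion would reach the q/k/v/o and gate/up/down projections inside every decoder layer, while the default skip tuple leaves the output head untouched.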

AutoBitnet

AutoBitnet is an automated tool that lets you train a BitNet b1.58 model on top of any LLaMA-architecture baseline on a Colab T4 GPU.


Special thanks to QueryLoopAI for sponsoring the compute for these experiments.

