A Gemini and Gemma tokenizer in Java
It’s always interesting to know how the sausage is made, don’t you think? That’s why, a while ago, I looked at embedding model tokenization, and I implemented a little visualization to see the tokens in a colorful manner. Yet, I was still curious to see how Gemini would tokenize text…
Both LangChain4j Gemini modules (from Vertex AI and from Google AI Labs) can count the tokens included in a piece of text. However, both do so by calling a REST API endpoint method called countTokens. This is not ideal, as it requires a network hop to get the token counts, thus adding undesired extra latency. Wouldn't it be nicer if we could count tokens locally instead?
Interestingly, both Gemini and the open-weights Gemma models share the same tokenizer and token vocabulary. Also, the tokenizer is based on SentencePiece, a tokenizer/detokenizer implementing the byte-pair encoding (BPE) and unigram language model algorithms.
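To make the BPE idea a bit more concrete, here is a toy sketch of how learned merge rules turn a word into tokens: start from single characters, then apply each merge rule in priority order, fusing every matching adjacent pair. The merge list used below is hypothetical, and `▁` is the marker SentencePiece uses for a preceding space. SentencePiece's real BPE encoder picks merges using scores stored in the model, so treat this as an illustration of the technique, not a reimplementation of the Gemini/Gemma tokenizer.

```java
import java.util.ArrayList;
import java.util.List;

// Toy BPE encoder (illustration only, not SentencePiece's actual algorithm).
class ToyBpe {
    // Applies merge rules in list order; each rule {left, right} fuses every
    // adjacent occurrence of left followed by right into one symbol.
    static List<String> encode(String word, List<String[]> merges) {
        // Start from individual Unicode code points
        List<String> symbols = new ArrayList<>();
        word.codePoints().forEach(cp -> symbols.add(Character.toString(cp)));

        for (String[] rule : merges) {
            List<String> merged = new ArrayList<>();
            int i = 0;
            while (i < symbols.size()) {
                boolean match = i + 1 < symbols.size()
                        && symbols.get(i).equals(rule[0])
                        && symbols.get(i + 1).equals(rule[1]);
                if (match) {
                    merged.add(rule[0] + rule[1]);
                    i += 2; // consume both halves of the merged pair
                } else {
                    merged.add(symbols.get(i));
                    i += 1;
                }
            }
            symbols.clear();
            symbols.addAll(merged);
        }
        return symbols;
    }
}
```

With the hypothetical rules `{"▁", "L"}` then `{"L", "M"}`, `ToyBpe.encode("▁LLM", ...)` goes from `[▁, L, L, M]` to `[▁L, L, M]` and finally `[▁L, LM]`, the same split Gemini's tokenizer produces for "LLM" in the output further down.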
If you look at the Gemma model files on Hugging Face, you’ll see a tokenizer.json file that you can open to see the tokens available in the vocabulary, and a tokenizer.model file, which is the binary SentencePiece model file.
Since those files contain the list of tokens supported by Gemini and Gemma, along with how they are encoded, I was curious to see whether I could implement a Java tokenizer that runs locally, rather than calling a remote endpoint.
The SentencePiece implementation from Google is a C++ library, but I didn't really feel like wrapping it myself with JNI. Fortunately, I discovered that the DJL project had already done the JNI wrapping job.
So let’s see how to tokenize text for Gemini and Gemma, in Java!
Gemini and Gemma tokenization in Java with DJL
First of all, let’s set up the dependency on DJL’s SentencePiece module (Gradle or Maven):
implementation 'ai.djl.sentencepiece:sentencepiece:0.30.0'
<dependency>
    <groupId>ai.djl.sentencepiece</groupId>
    <artifactId>sentencepiece</artifactId>
    <version>0.30.0</version>
</dependency>
I saved the tokenizer.model file locally. Note that it's a 4MB file, as Gemini and Gemma have a very large vocabulary of around a quarter of a million tokens!
Now, let’s instantiate an SpTokenizer object that loads this vocabulary file, and tokenize some text:
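import ai.djl.sentencepiece.SpTokenizer;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

public class GeminiTokenizer {
    public static void main(String[] args) throws Exception {
        // Load the SentencePiece model shared by Gemini and Gemma
        Path model = Paths.get("src/test/resources/gemini/tokenizer.model");
        byte[] modelFileBytes = Files.readAllBytes(model);

        // SpTokenizer holds native (JNI) resources, hence try-with-resources
        try (SpTokenizer tokenizer = new SpTokenizer(modelFileBytes)) {
            List<String> tokens = tokenizer.tokenize("""
                When integrating an LLM into your application to extend it and \
                make it smarter, it's important to be aware of the pitfalls and \
                best practices you need to follow to avoid some common problems \
                and integrate them successfully. This article will guide you \
                through some key best practices that I've come across.
                """);
            // Print each token in brackets to make its boundaries visible
            for (String token : tokens) {
                System.out.format("[%s]%n", token);
            }
            System.out.println("Token count: " + tokens.size());
        }
    }
}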
When running this Java class, you’ll see the following output:
[When]
[▁integrating]
[▁an]
[▁L]
[LM]
[▁into]
[▁your]
[▁application]
...
Token count: 61
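In this output, the `▁` character (U+2581) is SentencePiece's visible stand-in for a space, which is why "LLM" appears as `[▁L]` followed by `[LM]`: only the first piece of a word carries the marker. Going back from pieces to text is then mostly a matter of concatenating them and turning each `▁` back into a space. The sketch below assumes exactly that and nothing more; it deliberately ignores special tokens and byte-fallback pieces, which a real detokenizer has to handle.

```java
import java.util.List;

// Simplified detokenizer sketch: handles only the meta-space marker.
class PieceDetokenizer {
    // U+2581 LOWER ONE EIGHTH BLOCK, SentencePiece's placeholder for a space
    static final char META_SPACE = '\u2581';

    // Concatenates the pieces, then replaces every meta-space with a real space
    static String detokenize(List<String> pieces) {
        return String.join("", pieces).replace(META_SPACE, ' ');
    }
}
```

Applied to the first five tokens above, it yields `When integrating an LLM`. Note that the first token, `[When]`, has no leading `▁` in this output, so no stray leading space needs to be trimmed.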
Next steps
Do we need next steps? Yes, why not! My idea is to contribute a tokenizer module to LangChain4j, so that the Vertex AI Gemini and the Google AI Gemini modules can both import it, instead of relying on remote endpoint calls to count tokens.
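One way to picture that contribution is a small abstraction that counts tokens in-process, with the tokenizing function injected so that DJL's `SpTokenizer::tokenize` (or a fake, in tests) can be plugged in. The `TokenCounter` and `LocalTokenCounter` names below are hypothetical: they are not existing LangChain4j types, just a sketch of the idea under that assumption.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical abstraction, not an existing LangChain4j interface.
interface TokenCounter {
    int countTokens(String text);
}

// Counts tokens locally, without any network hop, using an injected tokenizer
// function, e.g. SpTokenizer::tokenize from DJL.
final class LocalTokenCounter implements TokenCounter {
    private final Function<String, List<String>> tokenize;

    LocalTokenCounter(Function<String, List<String>> tokenize) {
        this.tokenize = tokenize;
    }

    @Override
    public int countTokens(String text) {
        return tokenize.apply(text).size();
    }
}
```

Injecting a function rather than the DJL class directly keeps the counter testable without loading the 4MB model, and leaves the native resource lifecycle (closing the `SpTokenizer`) to the caller.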
Originally published at https://glaforge.dev on October 4, 2024.