docs: add canonical tooling corpus (147 files) from Google/HF/frameworks

Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mortdecai
2026-04-18 12:24:48 -04:00
parent 5011059f5d
commit eecebe7ef5
149 changed files with 181297 additions and 0 deletions
+226
View File
@@ -0,0 +1,226 @@
# Google-official Gemma tooling (as of 2026-04-18)
Downloaded corpus of canonical Google / Google-DeepMind Gemma tooling. This
directory mirrors only **upstream-authored** material — no third-party forks,
no community ports, no Ollama-specific content (that lives in
`../../CORPUS_ollama_variants.md`).
Reach for this directory when you need to verify what the canonical code/docs
actually say (prompt tokens, API shapes, supported variants) versus what a
third-party wrapper claims they say.
## Top-line findings (flag for cross-check with rest of corpus)
1. **Canonical JAX/Flax library (`google-deepmind/gemma`) has first-class
Gemma 4 support today** — `gm.nn.Gemma4_E4B()`,
`gm.ckpts.CheckpointPath.GEMMA4_E4B_IT`, and the unified `ChatSampler` /
`ToolSampler` API explicitly lists "2, 3, 3n, 4" as supported. This is the
least-friction Python path if you want the actual reference behavior.
2. **`google/gemma_pytorch` has NO Gemma 4 support** as of last push
(2025-05-30). `scripts/run.py` validates variant in
`['2b', '2b-v2', '7b', '9b', '27b', '1b']`; `scripts/run_multimodal.py` in
`['4b', '12b', '27b_v3']` (all Gemma 3). If someone tells you to "use
the official PyTorch repo" for Gemma 4, they're wrong — it's stale.
3. **`google/gemma.cpp` README says Gemma 2-3 + PaliGemma 2 only** (no Gemma 4
yet), but the repo is actively pushed and explicitly notes active work
happens on the `dev` branch. Worth rechecking `dev` for Gemma 4 support.
4. **Gemma 4 uses a NEW prompt-token syntax** distinct from Gemma 1/2/3:
- Gemma 1/2/3: `<start_of_turn>` / `<end_of_turn>` (symmetric angle brackets)
- Gemma 4: `<|turn>` / `<turn|>` (asymmetric pipe-brackets)
- Plus Gemma-4-new: `<|tool>`/`<tool|>`, `<|tool_call>`/`<tool_call|>`,
`<|tool_response>`/`<tool_response|>`, `<|think|>`,
`<|channel>`/`<channel|>`, `<|image>`/`<image|>`, `<|audio>`/`<audio|>`,
string delimiter `<|"|>`.
- Roles are named directly: `system`, `user`, `model` (no role brackets).
This directly contradicts any chat template built against Gemma 3 tokens.
`CORPUS_tool_calling_format.md` already captures the tool tokens correctly
but does NOT yet document the turn-token change or the thinking tokens.
5. **`gemma.cpp` ships an HTTP API server (`gemma_api_server`) that speaks
the Google Gemini API protocol** (`POST /v1beta/models/<model>:generateContent`,
SSE streaming, session management). This is a canonical Google-built
alternative to Ollama that implements the *real* Gemini REST API locally.
See `gemma-cpp/API_SERVER_README.md`.
6. **Tool use was NOT a trained capability in Gemma 1/2/3** — the DeepMind
`colabs/tool_use.ipynb` explicitly disclaims: *"The Gemma 1, 2 and 3 models
were not specifically trained for tool use. This is more a proof-of-concept
than an officially supported feature."* Gemma 4 is notably absent from that
caveat; the cookbook and blog confirm Gemma 4 has **native function
calling** as a first-class trained capability.
7. **No Gemma 4 technical-report PDF exists yet.** All conventional URLs
(`storage.googleapis.com/deepmind-media/gemma/Gemma4Report.pdf`,
`goo.gle/gemma4report`) return 404/redirect-to-google.com, and the
DeepMind repo README explicitly says "Gemma 4 (Coming soon)". Current
most-authoritative scientific document for the family is the Gemma 3
technical report (arXiv:2503.19786), downloaded here.
8. **Cookbook ships a Gemma-4-specific agentic reference app**
(`apps/Gemma_4_HDP_Agentic_Security/`) demonstrating how to cryptographically
gate Gemma 4's native function calls with Ed25519-signed delegation tokens
(IETF draft `draft-helixar-hdp-agentic-delegation-00`). A more
production-shaped pattern than the toy `tool_use.ipynb`.
## File index
### `deepmind-gemma/` — JAX/Flax reference (the primary Python library)
Upstream: https://github.com/google-deepmind/gemma (`main`, pushed 2026-04-17).
| File | What | Why keep |
|------|------|----------|
| `README.md` | PyPI `gemma` package entry point | Shows canonical `gm.nn.Gemma4_E4B()` API, `ChatSampler` multi-turn/multi-modal example |
| `example_multimodal.py` | Image-captioning fine-tune (Kauldron config) | Canonical end-to-end SFT example; docstring shows exact `<start_of_turn>user / <start_of_image> / <end_of_turn>` interleave for Gemma 3 |
| `example_lora.py` | LoRA fine-tuning recipe | Reach for this if doing PEFT against a Gemma 4 checkpoint |
| `example_dpo.py` | Direct Preference Optimization recipe | Reference for preference-alignment post-training |
| `example_classification.py` | Classification fine-tune | Shows Gemma as a feature extractor |
| `example_sharding.py` | Multi-device sharding | Reference for running >E4B on multi-GPU/TPU |
| `colab_tool_use.ipynb` | Tool-use demo (`ToolSampler`) | Important caveat inside: "not specifically trained for tool use" for Gemma 1/2/3; shows the `gm.tools.Tool` base class API |
| `colab_sampling.ipynb` | Basic inference / chat notebook | Starter-grade canonical sampling example |
Other scripts in the repo (not downloaded, cherry-picked above): `seq2seq.py`, `npo.py`, colabs for `quantization_aware_training`, `sharding`, `tokenizer`, `multimodal`, `finetuning`, `lora_finetuning`, `lora_sampling`. Fetch directly from https://github.com/google-deepmind/gemma/tree/main when needed.
### `gemma-pytorch/` — PyTorch reference (STALE for Gemma 4)
Upstream: https://github.com/google/gemma_pytorch (`main`, pushed 2025-05-30).
| File | What | Why keep |
|------|------|----------|
| `README.md` | Entry-point docs | Only documents up through Gemma 3; no Gemma 4 |
| `run.py` | Text-only inference entry point | Variant whitelist `['2b','2b-v2','7b','9b','27b','1b']` — Gemma 1/2 only |
| `run_multimodal.py` | Multimodal inference entry point | Variant whitelist `['4b','12b','27b_v3']` — Gemma 3 only. Shows exact interleaved `<start_of_turn>user\n`, image, `text, <end_of_turn>\n<start_of_turn>model` pattern |
| `run_xla.py` | TPU/XLA inference | Reference for running Gemma 3 on TPU |
**Do not reach for this repo for Gemma 4 work** until it's updated. Use the
DeepMind JAX lib, Hugging Face `transformers`, or gemma.cpp instead.
### `gemma-cpp/` — C++ reference inference
Upstream: https://github.com/google/gemma.cpp (`main`, pushed 2026-04-17; active dev on `dev` branch).
| File | What | Why keep |
|------|------|----------|
| `README.md` | Project overview, build instructions | States "Gemma 2-3 + PaliGemma 2" in features; Gemma 4 status unclear from `main` — check `dev` branch |
| `API_SERVER_README.md` | HTTP API server that speaks Gemini API protocol | **Most interesting find** — canonical drop-in for apps written against the Gemini API, runs locally. `POST /v1beta/models/<model>:generateContent`, SSE streaming, session KV-cache |
| `examples_README.md` | Pointer to `hello_world` / `simplified_gemma` minimal embedding examples | Starting point for embedding gemma.cpp into your own C++ binary |
### `cookbook/` — Official recipes and end-to-end apps
Upstream: https://github.com/google-gemma/cookbook (`main`, pushed 2026-04-17).
**Note:** `google-gemini/gemma-cookbook` now 301-redirects here; use the
`google-gemma/cookbook` URL going forward.
| File | What | Why keep |
|------|------|----------|
| `README.md` | Cookbook index | Authoritative list of Gemma variants incl. Gemma 4 (E2B / E4B / 26B A4B / 31B), the ecosystem (FunctionGemma, MedGemma, PaliGemma 2, RecurrentGemma, ShieldGemma 2, T5Gemma, TranslateGemma, TxGemma, VaultGemma, EmbeddingGemma) |
| `tutorials_RAG_EmbeddingGemma.ipynb` | RAG with EmbeddingGemma | Currently the only notebook in `tutorials/` — reflects the "latest tested" tier |
| `docs_gemma_chat.ipynb` | Chatbot with Gemma on Keras | Documents the `__START_TURN_USER__ = "<start_of_turn>user\n"` / `__END_TURN__ = "<end_of_turn>\n"` format explicitly; Gemma 2 example, but the class is the canonical illustration of the Gemma 1/2/3 chat template |
| `apps_Gemma4_HDP_AgenticSecurity_README.md` | README for the HDP agentic-security reference app | Gemma-4-specific demo; real production pattern for gating native function calls |
| `apps_Gemma4_HDP_hdp_middleware.py` | Drop-in middleware (`HDPMiddleware.gate()`) | Wraps any Gemma 4 tool executor with Ed25519-signed HDT verification |
| `apps_Gemma4_HDP_AgenticSecurity.ipynb` | Walkthrough notebook | End-to-end: load Gemma 4, issue tokens, gate function calls |
Other cookbook content worth noting (not downloaded — fetch on demand):
- `docs/capabilities/thinking.ipynb` (438 KB) — Gemma 4 thinking-mode notebook
- `docs/capabilities/audio.ipynb` — audio-input capability
- `docs/functiongemma/{finetuning-with-functiongemma,full-function-calling-sequence-with-functiongemma,function-calling-with-hf}.ipynb`**FunctionGemma** is a separate fine-tune on the Gemma 3 270M IT checkpoint specifically for function calling; distinct from Gemma 4's native function calling
- `docs/core/pytorch_gemma.ipynb`, `keras_inference.ipynb`, `huggingface_*.ipynb` — framework-specific recipes
- `docs/integrations/langchain.ipynb` — LangChain integration
- `experiments/{MedGemma,TxGemma}/` and `experiments/[T5Gemma]Example.ipynb`, `[VaultGemma]FineTuning_Inference_Huggingface.ipynb`, etc. — domain-specific Gemma variants
### `docs/` — Canonical ai.google.dev pages (HTML cached)
Verified URLs below; HTML snapshots saved for verbatim preservation.
| File | Source URL |
|------|-----------|
| `ai-google-dev_core.html` | https://ai.google.dev/gemma/docs/core — Gemma 4 overview |
| `ai-google-dev_model_card_4.html` | https://ai.google.dev/gemma/docs/core/model_card_4 — Gemma 4 model card |
| `ai-google-dev_prompt_formatting_gemma4.html` | https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4 — **Gemma 4 prompt tokens (new `<\|turn>`/`<turn\|>` syntax)** |
| `ai-google-dev_function_calling_gemma4.html` | https://ai.google.dev/gemma/docs/capabilities/text/function-calling-gemma4 — **Gemma 4 native function calling spec** |
| `ai-google-dev_formatting.html` | https://ai.google.dev/gemma/docs/formatting — Gemma 1/2/3 prompt format (`<start_of_turn>`/`<end_of_turn>`) |
| `blog_announcement.html` | https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/ — Gemma 4 launch blog, 2026-04-02 |
Other canonical doc URLs (verified to exist, not snapshotted here — visit
directly):
- https://ai.google.dev/gemma/docs — top-level Gemma hub
- https://ai.google.dev/gemma/docs/releases — release history
- https://ai.google.dev/gemma/docs/functiongemma — FunctionGemma variant
- https://ai.google.dev/gemma/docs/core/deploy_to_cloud_run_from_ai_studio — AI Studio → Cloud Run
- https://docs.cloud.google.com/vertex-ai/generative-ai/docs/open-models/use-gemma — Vertex AI
- https://aistudio.google.com — AI Studio
- https://gemma-llm.readthedocs.io — DeepMind JAX lib docs
- https://www.kaggle.com/models/google/gemma-4 — Gemma 4 on Kaggle
- https://huggingface.co/collections/google/gemma-4 — Gemma 4 on HF
### `tech-report/`
| File | What | Source |
|------|------|--------|
| `Gemma3Report.pdf` | **Gemma 3 Technical Report** (arXiv:2503.19786, 2025-03-12) | https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf |
No Gemma 4 technical report exists yet. Probed paths that return 404:
- `Gemma4Report.pdf`, `gemma4-report.pdf`, `Gemma4Report_v1.pdf` under
`storage.googleapis.com/deepmind-media/gemma/`
- `goo.gle/gemma4report` (not configured — redirects to google.com)
DeepMind repo README line: **"Gemma 4 (Coming soon)"**. The Gemma 3 report
remains the most-authoritative Google-DeepMind scientific document for the
family and is the correct citation for architecture fundamentals (Grouped-Query
Attention with post-norm/pre-norm RMSNorm, 5:1 local/global attention layer
interleave, 1024-token local sliding window, RoPE base 1M on global / 10k on
local, SigLIP 400M vision encoder at 896×896 shared across 4B/12B/27B and
frozen during training, SentencePiece tokenizer with 262k vocab shared with
Gemini 2.0, knowledge distillation during pre-training, QAT checkpoints via
5k-step fine-tune for int4/SFP8). Per-variant parameter counts for Gemma 3:
1B = 698M non-embedding + 302M embedding, 4B = 3209M + 675M, 12B = 10759M +
1012M, 27B = 25600M + 1416M.
## Canonical Gemma 4 prompt format (verified 2026-04-18)
**Source:** https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4 and
https://ai.google.dev/gemma/docs/capabilities/text/function-calling-gemma4
Note the `<|turn>` / `<turn|>` are asymmetric — opening has the pipe on the
left, closing has the pipe on the right. Same for all paired delimiters.
```
<|turn>system
<|think|> (optional — activates thinking mode)
<|tool>declaration:FUNCTION_NAME{description:<|"|>...<|"|>,parameters:{properties:{...},required:[...]}}<tool|>
You are a helpful assistant.<turn|>
<|turn>user
What's the weather in Tokyo?<turn|>
<|turn>model
<|channel>thought
...internal reasoning...<channel|>
<|tool_call>call:get_current_weather{location:<|"|>Tokyo, JP<|"|>}<tool_call|>
<|tool_response>response:get_current_weather{temperature:15,weather:<|"|>sunny<|"|>}<tool_response|>
The current weather in Tokyo is 15 degrees and sunny.<turn|>
```
Recommended sampling (per model card, verified):
`temperature=1.0, top_p=0.95, top_k=64`. Tokenizer vocab = **262k** (same as
Gemini 2.0). **BOS token required** — prepend `[BOS]` / set `add_bos=True`.
**Gemma 1/2/3 prompt format (different — for reference):**
```
<start_of_turn>user
[message]<end_of_turn>
<start_of_turn>model
[response]<end_of_turn>
```
Gemma 1/2/3 have no trained tool-use or thinking tokens. PT models end with
`<eos>`; IT models end with `<end_of_turn>`.
## Gemma 4 variants (canonical spec from model card)
| Variant | Params | Active | Context | Multimodal |
|---------|--------|--------|---------|------------|
| Gemma 4 E2B | 2.3B effective (5.1B w/ embeddings), 35 layers | — | 128K | text+image+audio (30s max) |
| Gemma 4 E4B | 4.5B effective (8B w/ embeddings), 42 layers | — | 128K | text+image+audio (30s max) |
| Gemma 4 26B A4B | 25.2B total (MoE), 30 layers | 3.8B | 256K | text+image |
| Gemma 4 31B | 30.7B dense, 60 layers | — | 256K | text+image |
All variants: Apache 2.0, base + instruction-tuned (`-it`), 140+ languages,
native function calling, native structured JSON output. Vision encoder = 150M
(E2B/E4B) or 550M (26B/31B). Image resolution token budgets: 70, 140, 280,
560, 1120. Released 2026-04-02.
## Fetched using
All files fetched via `curl -sL` from `raw.githubusercontent.com` on
2026-04-18. Repos enumerated via the GitHub API
(`https://api.github.com/repos/<owner>/<repo>/contents/<path>`). Google docs
pages fetched via WebFetch tool. No GitHub auth needed for public raw files
(unauthenticated rate limit = 60 req/hr, sufficient for this task).
@@ -0,0 +1,80 @@
# Welcome to the Gemma Cookbook
This is a collection of guides and examples for [Google Gemma](https://ai.google.dev/gemma/).
> **Disclaimer:** Gemma is a family of developer-focused models built by Google DeepMind. This cookbook is a collection of guides and examples for Google Gemma. Please keep in mind that Gemma is an open model and can hallucinate as you build on examples in this cookbook.
## Repository Structure
* [**Tutorials**](tutorials/): The latest tested notebooks for Gemma models and variants.
* [**Apps**](apps/): Full-stack demos and complex end-to-end use cases.
* [**Experiments**](experiments/): Research-focused model notebooks, including [TxGemma](experiments/TxGemma) and [MedGemma](experiments/MedGemma).
* [**Responsible**](responsible/): Notebooks for responsible AI development.
* [**Docs**](docs/): Core documentation, capabilities, and technical guides.
* [**Archive**](.archive/): All older notebooks and historical examples.
## Get started with the Gemma models
Gemma is a family of lightweight, generative artificial intelligence (AI) open models, built from the same research and technology used to create the Gemini models. The Gemma model family includes:
* Gemma\
The core models of the Gemma family.
* [Gemma](https://ai.google.dev/gemma/docs/core/model_card)\
For a variety of text generation tasks and can be further tuned for specific use cases
* [Gemma 2](https://ai.google.dev/gemma/docs/core/model_card_2)\
Higher-performing and more efficient, available in 2B, 9B, 27B parameter sizes
* [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3)\
Longer context window and handling text and image input, available in 1B, 4B, 12B, and 27B parameter sizes
* [Gemma 3n](https://ai.google.dev/gemma/docs/gemma-3n/model_card) \
Designed for efficient execution on low-resource devices. Handling text, image, video, and audio input, available in E2B and E4B parameter sizes
* [Gemma 4](https://ai.google.dev/gemma/docs/core/model_card_4)\
Well-suited for reasoning, agentic workflows, coding, and multimodal understanding, available in E2B, E4B, 26B A4B, and 31B parameter sizes.
* Gemma variants
* [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)\
Fine-tuned for a variety of coding tasks
* [DataGemma](https://ai.google.dev/gemma/docs/datagemma)\
Fine-tuned for using Data Commons to address AI hallucinations
* [FunctionGemma](https://ai.google.dev/gemma/docs/functiongemma)\
Fine-tuned on Gemma 3 270M IT checkpoint for function calling
* [MedGemma](https://developers.google.com/health-ai-developer-foundations/medgemma)
The MedGemma collection contains Google's most capable open models for medical text and image comprehension, built on Gemma 3. Developers can use MedGemma to accelerate building healthcare-based AI applications. MedGemma comes in two variants: a 4B multimodal version and a 27B text-only version.
* [PaliGemma](https://ai.google.dev/gemma/docs/paligemma/model-card)\
Vision Language Model\
For a deeper analysis of images and provide useful insights
* [PaliGemma 2](https://ai.google.dev/gemma/docs/paligemma/model-card-2)\
VLM which incorporates the capabilities of the Gemma 2 models
* [RecurrentGemma](https://ai.google.dev/gemma/docs/recurrentgemma)\
Based on [Griffin](https://arxiv.org/abs/2402.19427) architecture\
For a variety of text generation tasks
* [ShieldGemma](https://ai.google.dev/gemma/docs/shieldgemma/model_card)\
Fine-tuned for evaluating the safety of text prompt input and text output responses against a set of defined safety policies
* [ShieldGemma 2](https://ai.google.dev/gemma/docs/shieldgemma/model_card_2)\
Fine-tuned on Gemma 3 4B IT checkpoint for image safety classification
* [T5Gemma](https://deepmind.google/models/gemma/t5gemma)\
A collection of encoder-decoder models that provide a strong quality-inference efficiency tradeoff
* [TranslateGemma](https://huggingface.co/collections/google/translategemma)\
A collection of open model designed to handle translation tasks across 55 languages
* [TxGemma](https://deepmind.google/models/gemma/txgemma)\
A collection of open models designed to improve the efficiency of therapeutic development
* [VaultGemma](https://deepmind.google/models/gemma/vaultgemma)\
An open model trained from the ground up using differential privacy to prevent memorization and leaking of training data examples
You can find the Gemma models on the Hugging Face Hub, Kaggle, Google Cloud Vertex AI Model Garden, and [ai.nvidia.com](https://ai.nvidia.com).
## Additional Resources
* [MedGemma on Google-Health](https://github.com/Google-Health/medgemma/tree/main/notebooks) : Google-Health has additional notebooks for using MedGemma
* [Gemma on Google Cloud](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/open-models) : GCP open models has additional notebooks for using Gemma
## Get help
Ask a Gemma cookbook-related question on the [developer forum](https://discuss.ai.google.dev/c/gemma/10), or open an [issue](https://github.com/google-gemini/gemma-cookbook/issues) on GitHub.
## Wish list
If you want to see additional cookbooks implemented for specific features/integrations, please open a new issue with [“Feature Request” template](https://github.com/google-gemini/gemma-cookbook/issues/new?template=feature_request.yml).
If you want to make contributions to the Gemma Cookbook project, you are welcome to pick any idea in the [“Wish List”](https://github.com/google-gemini/gemma-cookbook/labels/wishlist) and implement it.
## Contributing
Contributions are always welcome. Please read [contributing](https://github.com/google-gemini/gemma-cookbook/blob/main/CONTRIBUTING.md) before implementation.
Thank you for developing with Gemma! Were excited to see what you create.
## Translation of this repository
* [Traditional Chinese](https://github.com/doggy8088/gemma-cookbook)
* [Simplified Chinese](https://github.com/xiaoxiong1006/gemma-cookbook)
@@ -0,0 +1,526 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "colab-badge"
},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/apps/Gemma_4_HDP_Agentic_Security/Gemma_4_HDP_Agentic_Security.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "byline"
},
"source": [
"# Securing Gemma 4 Agentic Workflows with HDP\n",
"\n",
"**Author:** Asiri Dalugoda, Helixar Limited ([@asiridalugoda](https://github.com/asiridalugoda)) | [helixar.ai](https://helixar.ai)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gpu-instructions"
},
"source": [
"## Before you begin\n",
"\n",
"This notebook requires a GPU runtime. To enable GPU in Colab:\n",
"1. Go to **Runtime → Change runtime type**\n",
"2. Set **Hardware accelerator** to **GPU** (T4 is sufficient for E4B)\n",
"3. Click **Save**\n",
"\n",
"You will also need a **Hugging Face token** to download Gemma 4 (gated model):\n",
"1. Go to [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
"2. Create a token with **Read** access\n",
"3. Accept the Gemma 4 model license at [huggingface.co/google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)\n",
"4. Run the cell below to authenticate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hf-login"
},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "overview"
},
"source": [
"# Securing Gemma 4 Agentic Workflows with HDP\n",
"\n",
"**Human Delegation Provenance (HDP)** is an open protocol that adds a cryptographic chain-of-custody to AI agent function calls — ensuring every tool invocation can be traced back to an authorized human principal.\n",
"\n",
"This notebook demonstrates how to integrate HDP with Gemma 4's native function-calling capability to:\n",
"\n",
"- **Verify** that Gemma 4's function calls were authorized by a human principal before execution\n",
"- **Classify** actions by irreversibility (read-only → irreversible → physical actuation)\n",
"- **Block** unauthorized or out-of-scope tool calls at the middleware layer\n",
"- **Audit** every decision with a pre-execution log\n",
"\n",
"This is particularly relevant for Gemma 4 deployments on edge devices (Jetson Nano, Raspberry Pi) where the model may be directing physical actuators offline with no out-of-band authorization check.\n",
"\n",
"**References:**\n",
"- HDP IETF draft: [draft-helixar-hdp-agentic-delegation-00](https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/)\n",
"- HDP-P (physical AI agents): [DOI 10.5281/ZENODO.19332440](https://doi.org/10.5281/ZENODO.19332440)\n",
"- Helixar: [helixar.ai](https://helixar.ai)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b3600ee25c8e"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7a80251f52b3"
},
"outputs": [],
"source": [
"!pip install -q transformers torch cryptography"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ed80fe18f255"
},
"outputs": [],
"source": [
"# Download the middleware\n",
"!wget -q https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/apps/Gemma_4_HDP_Agentic_Security/hdp_middleware.py\n",
"\n",
"from hdp_middleware import (\n",
" HDPDelegationToken,\n",
" HDPMiddleware,\n",
" IrreversibilityClass,\n",
" DEFAULT_TOOL_CLASS_MAP,\n",
")\n",
"from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey\n",
"import json"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e88bdc7b7265"
},
"source": [
"## 1. Load Gemma 4\n",
"\n",
"We use the 4B Effective model for this demo. For production agentic deployments, the 26B MoE or 31B Dense models are recommended."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1e4e7779806d"
},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"# For edge/robotics use cases: swap to google/gemma-4-E2B-it\n",
"MODEL_ID = \"google/gemma-4-E4B-it\"\n",
"\n",
"pipe = pipeline(\n",
" \"text-generation\",\n",
" model=MODEL_ID,\n",
" device_map=\"auto\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d91e36cfb0b2"
},
"source": [
"## 2. Define Tools\n",
"\n",
"Gemma 4 uses structured JSON function-calling. We define a tool set spanning different IrreversibilityClasses to demonstrate the middleware's classification behaviour."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1becdb52e7f8"
},
"outputs": [],
"source": [
"TOOLS = [\n",
" {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Get the current weather for a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\"type\": \"string\", \"description\": \"City name\"}\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"send_email\",\n",
" \"description\": \"Send an email to a recipient.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"to\": {\"type\": \"string\"},\n",
" \"subject\": {\"type\": \"string\"},\n",
" \"body\": {\"type\": \"string\"}\n",
" },\n",
" \"required\": [\"to\", \"subject\", \"body\"]\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"delete_file\",\n",
" \"description\": \"Permanently delete a file by path.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"path\": {\"type\": \"string\"}\n",
" },\n",
" \"required\": [\"path\"]\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"actuate_robot_arm\",\n",
" \"description\": \"Command a robot arm to move to a target position.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"joint_angles\": {\"type\": \"array\", \"items\": {\"type\": \"number\"}},\n",
" \"force_limit_n\": {\"type\": \"number\"}\n",
" },\n",
" \"required\": [\"joint_angles\"]\n",
" }\n",
" }\n",
"]\n",
"\n",
"# Tools indexed by name for lookup\n",
"TOOL_REGISTRY = {t[\"name\"]: t for t in TOOLS}\n",
"print(f\"Registered {len(TOOLS)} tools\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "722948b00a92"
},
"source": [
"## 3. Issue an HDP Delegation Token\n",
"\n",
"The human principal generates an Ed25519 keypair and issues an HDT that specifies:\n",
"- Which tools the agent is permitted to call\n",
"- The maximum IrreversibilityClass the agent can act on\n",
"- The token's lifetime"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b0622c68dfa5"
},
"outputs": [],
"source": [
"# Human principal generates their signing keypair\n",
"# In production: loaded from secure key storage (HSM, OS keychain, etc.)\n",
"principal_private_key = Ed25519PrivateKey.generate()\n",
"principal_public_key = principal_private_key.public_key()\n",
"\n",
"# Issue an HDT authorizing the Gemma 4 agent to call weather queries\n",
"# and send emails (Class 0 and Class 2), but NOT delete files or actuate hardware\n",
"token = HDPDelegationToken.issue(\n",
" principal_id=\"alice@example.com\",\n",
" agent_id=\"gemma4-agent-01\",\n",
" scope=[\"get_weather\", \"send_email\"],\n",
" max_class=IrreversibilityClass.CLASS_2,\n",
" ttl_seconds=3600,\n",
" private_key=principal_private_key,\n",
")\n",
"\n",
"print(json.dumps(token.to_dict(), indent=2))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e206f950f4bc"
},
"source": [
"## 4. Initialise the HDP Middleware\n",
"\n",
"The middleware takes the principal's **public key** only — it verifies but cannot issue tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e24676f528bf"
},
"outputs": [],
"source": [
"audit_log = []\n",
"\n",
"# Confirmation callback for Class 2 (irreversible) actions.\n",
"# In production: this would invoke a push notification, SMS OTP,\n",
"# or hardware confirmation device to the human principal.\n",
"def require_human_confirmation(tool_name: str, parameters: dict) -> bool:\n",
" print(f\"\\n⚠️ Class 2 action requested: {tool_name}\")\n",
" print(f\" Parameters: {json.dumps(parameters, indent=4)}\")\n",
" response = input(\" Confirm? [y/N]: \").strip().lower()\n",
" return response == \"y\"\n",
"\n",
"middleware = HDPMiddleware(\n",
" public_key=principal_public_key,\n",
" tool_class_map=DEFAULT_TOOL_CLASS_MAP,\n",
" confirmation_callback=require_human_confirmation,\n",
" audit_log=audit_log,\n",
")\n",
"\n",
"print(\"HDP middleware initialised.\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "72d56542eba0"
},
"source": [
"## 5. Gemma 4 Function Call → HDP Gate → Tool Execution\n",
"\n",
"This is the core integration pattern. Every function call Gemma 4 generates is passed through `middleware.gate()` before being forwarded to tool execution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "da20bc191e71"
},
"outputs": [],
"source": [
"# Simulated Gemma 4 function call outputs\n",
"# In production these come from parsing Gemma 4's structured JSON output\n",
"gemma_function_calls = [\n",
" # ✅ Should ALLOW — Class 0, in scope\n",
" {\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}},\n",
"\n",
" # ⚠️ Should CONFIRM then ALLOW — Class 2, in scope\n",
" {\"name\": \"send_email\", \"parameters\": {\n",
" \"to\": \"bob@example.com\",\n",
" \"subject\": \"Weekly report\",\n",
" \"body\": \"Please find attached.\"\n",
" }},\n",
"\n",
" # ❌ Should BLOCK — Class 2, NOT in HDT scope\n",
" {\"name\": \"delete_file\", \"parameters\": {\"path\": \"/data/important.csv\"}},\n",
"\n",
" # ❌ Should BLOCK — Class 3, physical actuation\n",
" {\"name\": \"actuate_robot_arm\", \"parameters\": {\n",
" \"joint_angles\": [0.0, -1.57, 0.0, -1.57, 0.0, 0.0],\n",
" \"force_limit_n\": 50.0\n",
" }},\n",
"]\n",
"\n",
"print(\"=\" * 60)\n",
"print(\"HDP VERIFICATION RESULTS\")\n",
"print(\"=\" * 60)\n",
"\n",
"for call in gemma_function_calls:\n",
" result = middleware.gate(call, token)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "be0d0dd05bce"
},
"source": [
"## 6. Audit Log\n",
"\n",
"Every decision is logged pre-execution. This is the HDP audit trail — a cryptographically linked record of what was authorized, by whom, and when."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e6dbab6d88d1"
},
"outputs": [],
"source": [
"print(\"\\nAUDIT LOG\")\n",
"print(\"-\" * 60)\n",
"for i, entry in enumerate(audit_log):\n",
" status = \"✅ ALLOWED\" if entry.allowed else \"❌ BLOCKED\"\n",
" print(f\"{i+1}. {status} | {entry.tool_name} | {entry.action_class.name} | {entry.reason}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bcadcb7040db"
},
"source": [
"## 7. Token Expiry and Scope Violation Demo\n",
"\n",
"Demonstrate that expired tokens and out-of-scope calls are blocked regardless of the action class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "deb2e3b6b20e"
},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Issue a token that's already expired\n",
"expired_token = HDPDelegationToken.issue(\n",
" principal_id=\"alice@example.com\",\n",
" agent_id=\"gemma4-agent-01\",\n",
" scope=[\"get_weather\"],\n",
" max_class=IrreversibilityClass.CLASS_0,\n",
" ttl_seconds=-1, # expired immediately\n",
" private_key=principal_private_key,\n",
")\n",
"\n",
"print(\"Testing expired token:\")\n",
"middleware.gate({\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}}, expired_token)\n",
"\n",
"print(\"\\nTesting call outside HDT scope:\")\n",
"middleware.gate({\"name\": \"delete_file\", \"parameters\": {\"path\": \"/etc/passwd\"}}, token)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b8f4acddb6fa"
},
"source": [
"## 8. Edge / Robotics Deployment (HDP-P)\n",
"\n",
"For Gemma 4 E2B/E4B running on Jetson Nano or Raspberry Pi and directing physical actuators, use the HDP-P extension. The key additions are:\n",
"\n",
"- **Embodiment context** — bind the token to a specific hardware ID\n",
"- **Policy attestation** — hash the deployed model weights into the token\n",
"- **Fleet delegation constraints** — prevent lateral movement across robot fleet\n",
"- **Pre-execution logging** — write audit records *before* actuator commands are issued\n",
"\n",
"See the [HDP-P specification](https://doi.org/10.5281/ZENODO.19332440) for the full EDT extension structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fcf7b451d175"
},
"outputs": [],
"source": [
"# Minimal HDP-P Embodied Delegation Token (EDT) extension example\n",
"# This shows how to attach physical constraints to an HDT\n",
"\n",
"hdp_p_extension = {\n",
" \"hdp-p\": {\n",
" \"version\": \"0.1\",\n",
" \"embodiment\": {\n",
" \"type\": \"mobile\",\n",
" \"platform\": \"raspberry-pi-5\",\n",
" \"hardware_id\": \"rpi-serial-XXXX\", # TPM-attested in production\n",
" \"workspace\": \"lab-zone-a\"\n",
" },\n",
" \"action_scope\": {\n",
" \"permitted_actions\": [\"move_base\", \"read_sensor\"],\n",
" \"excluded_zones\": [\"human-workspace\"],\n",
" \"force_limit_n\": 10.0,\n",
" \"max_velocity_ms\": 0.5\n",
" },\n",
" \"irreversibility\": {\n",
" \"max_class\": 1, # Class 1 max for this token\n",
" \"class2_requires_confirmation\": True,\n",
" \"class3_prohibited\": True\n",
" },\n",
" \"policy_attestation\": {\n",
" \"policy_hash\": \"sha256:abc123...\", # SHA-256 of deployed model weights\n",
" \"training_run_id\": \"gemma4-e2b-it\",\n",
" \"sim_validated\": True\n",
" },\n",
" \"delegation_scope\": {\n",
" \"fleet_delegation_permitted\": False, # No lateral movement\n",
" \"max_delegation_depth\": 0\n",
" }\n",
" }\n",
"}\n",
"\n",
"print(\"HDP-P EDT extension structure:\")\n",
"print(json.dumps(hdp_p_extension, indent=2))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b0af7c701dfc"
},
"source": [
"## Summary\n",
"\n",
"| Layer | What it solves | Tool |\n",
"|---|---|---|\n",
"| Gemma 4 function calling | Model generates structured tool calls | `pipeline(\"text-generation\")` |\n",
"| HDP middleware | Was this call authorized by a human? | `HDPMiddleware.gate()` |\n",
"| HDP-P EDT extension | Is this physical action within delegated bounds? | `hdp_p_extension` |\n",
"| Audit log | Pre-execution record of every decision | `audit_log` |\n",
"\n",
"The full HDP specification (IETF draft), HDP-P companion paper, TypeScript SDK, and Python bindings are available at:\n",
"\n",
"- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/\n",
"- **HDP-P paper:** https://doi.org/10.5281/ZENODO.19332440\n",
"- **GitHub:** https://github.com/Helixar-AI\n",
"- **Site:** https://helixar.ai"
]
}
],
"metadata": {
"colab": {
"name": "Gemma_4_HDP_Agentic_Security.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -0,0 +1,75 @@
# Gemma 4 + HDP: Securing Agentic Function Calls
This example demonstrates how to integrate the **Human Delegation Provenance (HDP)** protocol with **Gemma 4's native function-calling** to cryptographically verify that every tool invocation was authorized by a human principal before execution.
## The problem
Gemma 4 is purpose-built for agentic workflows. Its native function-calling lets it autonomously call tools and APIs across multi-step plans — on anything from a cloud workstation to a Raspberry Pi running a robot offline.
This creates a gap: when Gemma 4 generates a function call, there is no verifiable record that a human principal authorized that specific action. An injected prompt, a compromised system prompt, or a lateral pivot from another agent can trigger function calls that are indistinguishable from legitimate requests at the tool interface.
HDP closes this gap.
## What HDP does
HDP (IETF draft: `draft-helixar-hdp-agentic-delegation-00`) provides:
- **Ed25519-signed Delegation Tokens (HDTs)** issued by a human principal
- **Scope constraints** — which tools the agent is permitted to call
- **Irreversibility classification** (Class 03) — from read-only to physical actuation
- **Pre-execution verification** — the middleware gate runs *before* any tool executes
- **Audit log** — a tamper-evident record of every authorization decision
For Gemma 4 on **edge devices directing physical actuators** (Jetson Nano, Raspberry Pi + robot arm), the HDP-P companion specification adds embodiment constraints, policy attestation, and fleet delegation controls.
## Files
| File | Description |
|---|---|
| `Gemma_4_HDP_Agentic_Security.ipynb` | Full walkthrough notebook — load Gemma 4, issue tokens, gate function calls |
| `hdp_middleware.py` | Drop-in middleware — `HDPMiddleware.gate()` wraps any Gemma 4 tool executor |
## Quick start
```python
from hdp_middleware import HDPDelegationToken, HDPMiddleware, IrreversibilityClass
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
# Human principal issues a delegation token
private_key = Ed25519PrivateKey.generate()
token = HDPDelegationToken.issue(
principal_id="alice@example.com",
agent_id="gemma4-agent-01",
scope=["get_weather", "send_email"],
max_class=IrreversibilityClass.CLASS_2,
ttl_seconds=3600,
private_key=private_key,
)
# Middleware verifies every Gemma 4 function call before execution
middleware = HDPMiddleware(public_key=private_key.public_key())
result = middleware.gate(
function_call={"name": "send_email", "parameters": {"to": "bob@example.com", ...}},
token=token,
)
if result.allowed:
execute_tool(function_call)
```
## Irreversibility classes
| Class | Definition | Authorization |
|---|---|---|
| 0 | Fully reversible — reads, queries | HDT sufficient |
| 1 | Reversible with effort — writes, moves | HDT sufficient |
| 2 | Irreversible — send, delete, publish | HDT + principal confirmation |
| 3 | Irreversible + potentially harmful — physical actuation | Dual-principal required (HDP-P) |
## References
- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/
- **Zenodo DOI:** https://doi.org/10.5281/zenodo.19332023
- **HDP-P (physical AI):** https://doi.org/10.5281/ZENODO.19332440
- **Helixar:** https://helixar.ai
@@ -0,0 +1,390 @@
"""
HDP (Human Delegation Provenance) middleware for Gemma 4 function calling.
Intercepts Gemma 4 function call outputs and verifies that a valid HDP
Delegation Token (HDT) authorizes the requested action before forwarding
to the tool execution layer.
Reference: draft-helixar-hdp-agentic-delegation-00
https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/
DOI: 10.5281/zenodo.19332023
For physical AI agents (robots, edge devices), see HDP-P:
DOI: 10.5281/ZENODO.19332440
"""
import json
import time
import base64
import hashlib
import hmac
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional, Callable, Any
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
from cryptography.hazmat.primitives.serialization import (
Encoding,
PublicFormat,
PrivateFormat,
NoEncryption,
)
from cryptography.exceptions import InvalidSignature
# ---------------------------------------------------------------------------
# Irreversibility Classes (HDP-P §4.2)
# ---------------------------------------------------------------------------
class IrreversibilityClass(IntEnum):
"""
Classification of physical action reversibility (HDP-P §4.2).
For digital-only Gemma 4 deployments, all tool calls are Class 0 or 1.
For edge/robotics deployments (Jetson Nano, Raspberry Pi + actuators),
Class 2 and 3 require explicit pre-execution confirmation.
"""
CLASS_0 = 0 # Fully reversible — read-only, query, observe
CLASS_1 = 1 # Reversible with effort — write, create, move
CLASS_2 = 2 # Irreversible under normal conditions — delete, send, publish
CLASS_3 = 3 # Irreversible and potentially harmful — physical actuation
# Default tool → irreversibility class mapping.
# Deployments should override this for their specific tool set.
DEFAULT_TOOL_CLASS_MAP: dict[str, IrreversibilityClass] = {
# Class 0 — safe reads
"get_weather": IrreversibilityClass.CLASS_0,
"search_web": IrreversibilityClass.CLASS_0,
"read_file": IrreversibilityClass.CLASS_0,
"query_database": IrreversibilityClass.CLASS_0,
# Class 1 — reversible writes
"write_file": IrreversibilityClass.CLASS_1,
"create_record": IrreversibilityClass.CLASS_1,
"move_object": IrreversibilityClass.CLASS_1,
# Class 2 — irreversible digital actions
"send_email": IrreversibilityClass.CLASS_2,
"delete_file": IrreversibilityClass.CLASS_2,
"publish_post": IrreversibilityClass.CLASS_2,
"execute_transaction": IrreversibilityClass.CLASS_2,
# Class 3 — physical actuation (HDP-P scope)
"actuate_robot_arm": IrreversibilityClass.CLASS_3,
"command_vehicle": IrreversibilityClass.CLASS_3,
"dispense_fluid": IrreversibilityClass.CLASS_3,
"apply_force": IrreversibilityClass.CLASS_3,
}
# ---------------------------------------------------------------------------
# HDP Delegation Token (HDT)
# ---------------------------------------------------------------------------
@dataclass
class HDPDelegationToken:
"""
Simplified HDT structure derived from draft-helixar-hdp-agentic-delegation-00.
In production, HDTs are JOSE/JWT tokens signed with Ed25519.
This implementation provides the core claims structure and verification logic.
Claims:
iss — issuer (human principal identifier)
sub — subject (agent being delegated to)
iat — issued at (unix timestamp)
exp — expiry (unix timestamp)
scope — list of permitted tool names or wildcard patterns
max_irreversibility_class — ceiling on action class (03)
delegation_depth — remaining delegation hops permitted
nonce — replay-attack prevention
"""
iss: str
sub: str
iat: int
exp: int
scope: list[str]
max_irreversibility_class: IrreversibilityClass
delegation_depth: int = 1
nonce: str = ""
_signature: bytes = field(default=b"", repr=False)
_public_key: Optional[Ed25519PublicKey] = field(default=None, repr=False)
@classmethod
def issue(
cls,
principal_id: str,
agent_id: str,
scope: list[str],
max_class: IrreversibilityClass,
ttl_seconds: int = 3600,
delegation_depth: int = 1,
private_key: Optional[Ed25519PrivateKey] = None,
) -> "HDPDelegationToken":
"""
Issue a new HDT signed by the human principal's Ed25519 private key.
Args:
principal_id: Human principal identifier (e.g. "alice@example.com")
agent_id: Agent being delegated to (e.g. "gemma4-agent-01")
scope: List of permitted tool names. Use ["*"] for unrestricted.
max_class: Maximum IrreversibilityClass this token permits.
ttl_seconds: Token lifetime in seconds.
delegation_depth: How many times this token can be re-delegated.
private_key: Ed25519 private key for signing. Generated if None.
"""
now = int(time.time())
nonce = base64.urlsafe_b64encode(
hashlib.sha256(f"{principal_id}{now}".encode()).digest()[:16]
).decode()
token = cls(
iss=principal_id,
sub=agent_id,
iat=now,
exp=now + ttl_seconds,
scope=scope,
max_irreversibility_class=max_class,
delegation_depth=delegation_depth,
nonce=nonce,
)
if private_key is None:
private_key = Ed25519PrivateKey.generate()
token._public_key = private_key.public_key()
token._signature = private_key.sign(token._canonical_bytes())
return token
def _canonical_bytes(self) -> bytes:
"""Deterministic serialisation for signing/verification."""
payload = {
"iss": self.iss,
"sub": self.sub,
"iat": self.iat,
"exp": self.exp,
"scope": sorted(self.scope),
"max_irreversibility_class": int(self.max_irreversibility_class),
"delegation_depth": self.delegation_depth,
"nonce": self.nonce,
}
return json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
def verify(self, public_key: Ed25519PublicKey) -> bool:
"""Verify the token's Ed25519 signature."""
try:
public_key.verify(self._signature, self._canonical_bytes())
return True
except InvalidSignature:
return False
def is_expired(self) -> bool:
return int(time.time()) > self.exp
def permits_tool(self, tool_name: str) -> bool:
"""Check whether this token's scope covers the requested tool."""
if "*" in self.scope:
return True
return tool_name in self.scope
def permits_class(self, action_class: IrreversibilityClass) -> bool:
return action_class <= self.max_irreversibility_class
def to_dict(self) -> dict:
return {
"iss": self.iss,
"sub": self.sub,
"iat": self.iat,
"exp": self.exp,
"scope": self.scope,
"max_irreversibility_class": int(self.max_irreversibility_class),
"delegation_depth": self.delegation_depth,
"nonce": self.nonce,
}
# ---------------------------------------------------------------------------
# Verification result
# ---------------------------------------------------------------------------
@dataclass
class VerificationResult:
allowed: bool
reason: str
tool_name: str
action_class: IrreversibilityClass
token_iss: Optional[str] = None
requires_confirmation: bool = False
def __str__(self) -> str:
status = "ALLOWED" if self.allowed else "BLOCKED"
conf = " [CONFIRMATION REQUIRED]" if self.requires_confirmation else ""
return (
f"[HDP] {status}{conf} — tool={self.tool_name} "
f"class={self.action_class.name} reason={self.reason}"
)
# ---------------------------------------------------------------------------
# HDP Middleware
# ---------------------------------------------------------------------------
class HDPMiddleware:
"""
HDP verification gate for Gemma 4 function calls.
Sits between Gemma 4's function-call output and the tool execution layer.
For each function call Gemma 4 generates, this middleware:
1. Parses the tool name from the function call.
2. Looks up its IrreversibilityClass.
3. Verifies the attached HDT (signature, expiry, scope, class ceiling).
4. For Class 2 actions, invokes the confirmation callback.
5. Blocks Class 3 actions unless explicitly pre-authorized with
dual verification (HDP-P §5.4).
6. Logs all decisions before forwarding or blocking.
Usage:
middleware = HDPMiddleware(
public_key=principal_public_key,
tool_class_map=DEFAULT_TOOL_CLASS_MAP,
confirmation_callback=my_confirmation_fn,
)
# Wrap your tool executor:
result = middleware.gate(
function_call=gemma_output, # {"name": "...", "parameters": {...}}
token=hdp_token,
)
if result.allowed:
output = execute_tool(function_call)
"""
def __init__(
self,
public_key: Ed25519PublicKey,
tool_class_map: dict[str, IrreversibilityClass] = None,
confirmation_callback: Optional[Callable[[str, dict], bool]] = None,
default_class: IrreversibilityClass = IrreversibilityClass.CLASS_1,
audit_log: Optional[list] = None,
):
"""
Args:
public_key: Principal's Ed25519 public key for HDT verification.
tool_class_map: Mapping of tool names to IrreversibilityClass.
Defaults to DEFAULT_TOOL_CLASS_MAP.
confirmation_callback: Called for Class 2 actions. Receives
(tool_name, parameters) and returns bool.
If None, Class 2 actions are blocked.
default_class: Class assigned to unknown tools. Defaults to CLASS_1.
audit_log: Optional list to append VerificationResult records to.
"""
self.public_key = public_key
self.tool_class_map = tool_class_map or DEFAULT_TOOL_CLASS_MAP
self.confirmation_callback = confirmation_callback
self.default_class = default_class
self.audit_log = audit_log if audit_log is not None else []
def classify(self, tool_name: str) -> IrreversibilityClass:
"""Return the IrreversibilityClass for a tool name."""
return self.tool_class_map.get(tool_name, self.default_class)
def gate(
self,
function_call: dict,
token: HDPDelegationToken,
) -> VerificationResult:
"""
Main verification gate. Call this for every Gemma 4 function call.
Args:
function_call: Gemma 4 function call dict:
{"name": "tool_name", "parameters": {...}}
token: HDPDelegationToken issued by the human principal.
Returns:
VerificationResult — check .allowed before executing the tool.
"""
tool_name = function_call.get("name", "")
parameters = function_call.get("parameters", {})
action_class = self.classify(tool_name)
def _block(reason: str) -> VerificationResult:
result = VerificationResult(
allowed=False,
reason=reason,
tool_name=tool_name,
action_class=action_class,
token_iss=token.iss if token else None,
)
self.audit_log.append(result)
print(result)
return result
def _allow(reason: str, requires_confirmation: bool = False) -> VerificationResult:
result = VerificationResult(
allowed=True,
reason=reason,
tool_name=tool_name,
action_class=action_class,
token_iss=token.iss,
requires_confirmation=requires_confirmation,
)
self.audit_log.append(result)
print(result)
return result
# ── 1. Token presence ───────────────────────────────────────────────
if token is None:
return _block("no HDT present")
# ── 2. Expiry ───────────────────────────────────────────────────────
if token.is_expired():
return _block("HDT expired")
# ── 3. Signature ────────────────────────────────────────────────────
if not token.verify(self.public_key):
return _block("HDT signature invalid")
# ── 4. Scope ────────────────────────────────────────────────────────
if not token.permits_tool(tool_name):
return _block(f"tool '{tool_name}' not in HDT scope")
# ── 5. Irreversibility class ceiling ────────────────────────────────
if not token.permits_class(action_class):
return _block(
f"action class {action_class.name} exceeds HDT ceiling "
f"{token.max_irreversibility_class.name}"
)
# ── 6. Class 3 — always blocked without explicit dual verification ──
if action_class == IrreversibilityClass.CLASS_3:
# In production: implement dual-principal confirmation (HDP-P §5.4)
return _block(
"Class 3 physical action requires dual-principal confirmation "
"(HDP-P §5.4) — not implemented in this middleware instance"
)
# ── 7. Class 2 — confirmation callback required ─────────────────────
if action_class == IrreversibilityClass.CLASS_2:
if self.confirmation_callback is None:
return _block(
"Class 2 action requires confirmation callback — "
"none configured"
)
confirmed = self.confirmation_callback(tool_name, parameters)
if not confirmed:
return _block("Class 2 action — confirmation denied by principal")
return _allow("Class 2 confirmed by principal", requires_confirmation=True)
# ── 8. Class 0 / 1 — allow ─────────────────────────────────────────
return _allow(f"HDT valid, scope and class verified")
def gate_batch(
self,
function_calls: list[dict],
token: HDPDelegationToken,
) -> list[VerificationResult]:
"""Verify a list of function calls. Returns one result per call."""
return [self.gate(fc, token) for fc in function_calls]
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,925 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-u7xRR3DeFXz"
},
"source": [
"##### Copyright 2026 Google LLC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "oed1Dh9SeIlD"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A0UbyyBOeKmV"
},
"source": [
"# RAG with EmbeddingGemma\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/tutorials/RAG_with_EmbeddingGemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ND35JUp9ecq2"
},
"source": [
"EmbeddingGemma is a lightweight, open embedding model designed for fast, high-quality retrieval on everyday devices like mobile phones. At only 308 million parameters, it's efficient enough to run advanced AI techniques, such as Retrieval Augmented Generation (RAG), directly on your local machine with no internet connection required.\n",
"\n",
"## Setup\n",
"\n",
"Before starting this tutorial, complete the following steps:\n",
"\n",
"* Get access to EmbeddingGemma by logging into [Hugging Face](https://huggingface.co/google/embeddinggemma-300M) and selecting **Acknowledge license** for a Gemma model.\n",
"* Select a Colab runtime with sufficient resources to run\n",
" the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
"* Generate a Hugging Face [Access Token](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-token) and use it to login from Colab.\n",
"\n",
"This notebook will run on an NVIDIA T4 GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SZ8cw1nPf-NV"
},
"source": [
"### Install Python packages\n",
"\n",
"Install the libraries required for running the EmbeddingGemma model and generating embeddings. Sentence Transformers is a Python framework for text and image embeddings. For more information, see the [Sentence Transformers](https://www.sbert.net/) documentation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "daXx6O20Q7M0"
},
"outputs": [],
"source": [
"!pip install -q -U sentence-transformers transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kYiTsNFSjGJH"
},
"source": [
"After you have accepted the license, you need a valid Hugging Face Token to access the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eLagJ9aff9Ks"
},
"outputs": [],
"source": [
"# Login into Hugging Face Hub\n",
"from huggingface_hub import login\n",
"login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IiDcW_rmHBfx"
},
"source": [
"### Load language model\n",
"\n",
"You will use Gemma 4 E2B to generate responses."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "HX2JFDQI-vg8"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0b54b8b91da46fdb7ba8fd3aecb5002",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4291694230e74608a2808adde451bd0f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb31547f287441aba370d8e7a5fc351e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/1951 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0900cc228bed472094eb986719edfde4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/208 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3d195cea1ce044f4827cf06412aed5ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3bdb49b389aa4abfbb382fccaceb32be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "93e44e5dd0fe40d49e0cda367d98aeca",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"chat_template.jinja: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load Gemma\n",
"from transformers import pipeline\n",
"\n",
"MODEL_ID = \"google/gemma-4-E2B-it\"\n",
"\n",
"pipeline = pipeline(\n",
" task=\"text-generation\",\n",
" model=MODEL_ID,\n",
" device_map=\"auto\",\n",
" dtype=\"auto\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eAg-c23Wh0th"
},
"source": [
"### Load embedding model\n",
"\n",
"Use the `sentence-transformers` libraries to create an instance of a model class with EmbeddingGemma."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "6Jj1WiTSRRk-"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c5dc65f501e402fb5ec67d094d925e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"modules.json: 0%| | 0.00/573 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10b836de41a0410d8963be637ffa6b9d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config_sentence_transformers.json: 0%| | 0.00/997 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "68e29095344e4d24ac3898638f5a2b0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "376438be53e14e4b808ce63de0d32cb2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"sentence_bert_config.json: 0%| | 0.00/58.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7f2a5a56690e4ed5950ad0c278cc20c7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "264f0c21602640bd9ddfa9d405b5613f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/1.21G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "70eb603cffa948cc895046a8238abbae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/314 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aa608efe38f448898f8a01940a3684df",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b12b1756d9ac4145ae70595454e0e036",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/33.4M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6e735942c07444ebfcf2702673762b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"added_tokens.json: 0%| | 0.00/35.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ff92bb744fd54211b20f04aedebaa26d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/662 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "114a2560d2124889932f1a6436c4d6ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/312 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "292f471e215d4ac8a490508ce6963b01",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/134 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dce2f7bc57134d0180f3accdec8d5556",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"2_Dense/model.safetensors: 0%| | 0.00/9.44M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c245d417dc9f4d71850853a107379b16",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/134 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18def66743ae4738b940a4b20c434545",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"3_Dense/model.safetensors: 0%| | 0.00/9.44M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: cuda:0\n",
"SentenceTransformer(\n",
" (0): Transformer({'transformer_task': 'feature-extraction', 'modality_config': {'text': {'method': 'forward', 'method_output_name': 'last_hidden_state'}}, 'module_output_name': 'token_embeddings', 'architecture': 'Gemma3TextModel'})\n",
" (1): Pooling({'embedding_dimension': 768, 'pooling_mode': 'mean', 'include_prompt': True})\n",
" (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity', 'module_input_name': 'sentence_embedding', 'module_output_name': 'sentence_embedding'})\n",
" (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity', 'module_input_name': 'sentence_embedding', 'module_output_name': 'sentence_embedding'})\n",
" (4): Normalize({})\n",
")\n",
"Total number of parameters in the model: 307581696\n"
]
}
],
"source": [
"import torch\n",
"from sentence_transformers import SentenceTransformer\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"model_id = \"google/embeddinggemma-300M\"\n",
"model = SentenceTransformer(model_id).to(device=device)\n",
"\n",
"print(f\"Device: {model.device}\")\n",
"print(model)\n",
"print(\"Total number of parameters in the model:\", sum([p.numel() for _, p in model.named_parameters()]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8o2-nOX-aqRS"
},
"source": [
"### Using Prompts with EmbeddingGemma\n",
"\n",
"For RAG systems, use the following `prompt_name` values to create specialized embeddings for your queries and documents:\n",
"\n",
"* **For Queries:** Use `prompt_name=\"Retrieval-query\"`.<br>\n",
" ```python\n",
" query_embedding = model.encode(\n",
" \"How do I use prompts with this model?\",\n",
" prompt_name=\"Retrieval-query\"\n",
" )\n",
" ```\n",
"\n",
"* **For Documents:** Use `prompt_name=\"Retrieval-document\"`. To further improve document embeddings, you can also include a title by using the `prompt` argument directly:<br>\n",
" * **With a title:**<br>\n",
" ```python\n",
" doc_embedding = model.encode(\n",
" \"The document text...\",\n",
" prompt=\"title: Using Prompts in RAG | text: \"\n",
" )\n",
" ```\n",
" * **Without a title:**<br>\n",
" ```python\n",
" doc_embedding = model.encode(\n",
" \"The document text...\",\n",
" prompt=\"title: none | text: \"\n",
" )\n",
" ```\n",
"\n",
"### Further Reading\n",
"\n",
"* For details on all available EmbeddingGemma prompts, see the [model card](http://ai.google.dev/gemma/docs/embeddinggemma/model_card#prompt_instructions).\n",
"* For general information on prompt templates, see the [Sentence Transformer documentation](https://sbert.net/examples/sentence_transformer/applications/computing-embeddings/README.html#prompt-templates).\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "Y5hVNF3F-qZ7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available tasks:\n",
" query: \"task: search result | query: \"\n",
" document: \"title: none | text: \"\n",
" BitextMining: \"task: search result | query: \"\n",
" Clustering: \"task: clustering | query: \"\n",
" Classification: \"task: classification | query: \"\n",
" InstructionRetrieval: \"task: code retrieval | query: \"\n",
" MultilabelClassification: \"task: classification | query: \"\n",
" PairClassification: \"task: sentence similarity | query: \"\n",
" Reranking: \"task: search result | query: \"\n",
" Retrieval: \"task: search result | query: \"\n",
" Retrieval-query: \"task: search result | query: \"\n",
" Retrieval-document: \"title: none | text: \"\n",
" STS: \"task: sentence similarity | query: \"\n",
" Summarization: \"task: summarization | query: \"\n"
]
}
],
"source": [
"print(\"Available tasks:\")\n",
"for name, prefix in model.prompts.items():\n",
" print(f\" {name}: \\\"{prefix}\\\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eIfWZ_z3xDZq"
},
"source": [
"## Simple RAG example\n",
"\n",
"Retrieval is the task of finding the most relevant pieces of information from a large collection (a database, a set of documents, a website) based on the meaning of a query, not just keywords.\n",
"\n",
"Imagine you work for a company, and you need to find information from the internal employee handbook, which is stored as a collection of hundreds of documents."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"cellView": "form",
"id": "fbaiy-CXRAs7"
},
"outputs": [],
"source": [
"#@title Corp knowledge base\n",
"corp_knowledge_base = [\n",
" {\n",
" \"category\": \"HR & Leave Policies\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Procedure for Unscheduled Absence\",\n",
" \"content\": \"In the event of an illness or emergency preventing you from working, please notify both your direct manager and the HR department via email by 9:30 AM JST. The subject line should be 'Sick Leave - [Your Name]'. If the absence extends beyond two consecutive days, a doctor's certificate (診断書) will be required upon your return.\"\n",
" },\n",
" {\n",
" \"title\": \"Annual Leave Policy\",\n",
" \"content\": \"Full-time employees are granted 10 days of annual paid leave in their first year. This leave is granted six months after the date of joining and increases each year based on length of service. For example, an employee in their third year of service is entitled to 14 days per year. For a detailed breakdown, please refer to the attached 'Annual Leave Accrual Table'.\"\n",
" },\n",
" ]\n",
" },\n",
" {\n",
" \"category\": \"IT & Security\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Account Password Management\",\n",
" \"content\": \"If you have forgotten your password or your account is locked, please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions. For security reasons, the IT Help Desk cannot reset passwords over the phone or email. If you have not set up your security questions, please visit the IT support desk on the 12th floor of the Shibuya office with your employee ID card.\"\n",
" },\n",
" {\n",
" \"title\": \"Software Procurement Process\",\n",
" \"content\": \"All requests for new software must be submitted through the 'IT Service Desk' portal under the 'Software Request' category. Please include a business justification for the request. All software licenses require approval from your department head before procurement can begin. Please note that standard productivity software is pre-approved and does not require this process.\"\n",
" },\n",
" ]\n",
" },\n",
" {\n",
" \"category\": \"Finance & Expenses\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Expense Reimbursement Policy\",\n",
" \"content\": \"To ensure timely processing, all expense claims for a given month must be submitted for approval no later than the 5th business day of the following month. For example, all expenses incurred in July must be submitted by the 5th business day of August. Submissions after this deadline may be processed in the next payment cycle.\"\n",
" },\n",
" {\n",
" \"title\": \"Business Trip Expense Guidelines\",\n",
" \"content\": \"Travel expenses for business trips will, as a rule, be reimbursed based on the actual cost of the most logical and economical route. Please submit a travel expense application in advance when using the Shinkansen or airplanes. Taxis are permitted only when public transportation is unavailable or when transporting heavy equipment. Receipts are mandatory.\"\n",
" },\n",
" ]\n",
" },\n",
" {\n",
" \"category\": \"Office & Facilities\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Conference Room Booking Instructions\",\n",
" \"content\": \"All conference rooms in the Shibuya office can be reserved through your Calendar App. Create a new meeting invitation, add the attendees, and then use the 'Room Finder' feature to select an available room. Please be sure to select the correct floor. For meetings with more than 10 people, please book the 'Sakura' or 'Fuji' rooms on the 14th floor.\"\n",
" },\n",
" {\n",
" \"title\": \"Mail and Delivery Policy\",\n",
" \"content\": \"The company's mail services are intended for business-related correspondence only. For security and liability reasons, employees are kindly requested to refrain from having personal parcels or mail delivered to the Shibuya office address. The front desk will not be able to accept or hold personal deliveries.\"\n",
" },\n",
" ]\n",
" },\n",
"]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fvecfoko--hL"
},
"source": [
"And imagine you have a question like below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "wN-WHf26J89m"
},
"outputs": [],
"source": [
"question = \"How do I reset my password?\" # @param [\"How many days of annual paid leave do I get?\", \"How do I reset my password?\", \"What travel expenses can be reimbursed for a business trip?\", \"Can I receive personal packages at the office?\"] {type:\"string\", allow-input: true}\n",
"\n",
"# Define a minimum confidence threshold for a match to be considered valid\n",
"similarity_threshold = 0.4 # @param {\"type\":\"slider\",\"min\":0,\"max\":1,\"step\":0.1}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2CSeSmF7OuMB"
},
"source": [
"Search relevant document from the corporate knowledge base."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "NngqWUxOyrLS"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 1: Finding the best category...\n",
"['HR & Leave Policies', 'IT & Security', 'Finance & Expenses', 'Office & Facilities']\n",
"tensor([[0.5063, 0.5937, 0.5076, 0.4221]])\n",
" `-> ✅ Category Found: 'IT & Security' (Score: 0.59)\n",
"\n",
"Step 2: Finding the best document in that category...\n",
"['Account Password Management', 'Software Procurement Process']\n",
"tensor([[0.5829, 0.1531]])\n",
" `-> ✅ Document Found: 'Account Password Management' (Score: 0.58)\n"
]
}
],
"source": [
"# --- Helper Functions for Semantic Search ---\n",
"\n",
"def _calculate_best_match(similarities):\n",
" print(similarities)\n",
" if similarities is None or similarities.nelement() == 0:\n",
" return None, 0.0\n",
"\n",
" # Find the index and value of the highest score\n",
" best_index = similarities.argmax().item()\n",
" best_score = similarities[0, best_index].item()\n",
"\n",
" return best_index, best_score\n",
"\n",
"def find_best_category(model, query, candidates):\n",
" \"\"\"\n",
" Finds the most relevant category from a list of candidates.\n",
"\n",
" Args:\n",
" model: The SentenceTransformer model.\n",
" query: The user's query string.\n",
" candidates: A list of category name strings.\n",
"\n",
" Returns:\n",
" A tuple containing the index of the best category and its similarity score.\n",
" \"\"\"\n",
" if not candidates:\n",
" return None, 0.0\n",
"\n",
" # Encode the query and candidate categories for classification\n",
" query_embedding = model.encode(query, prompt_name=\"Classification\")\n",
" candidate_embeddings = model.encode(candidates, prompt_name=\"Classification\")\n",
"\n",
" print(candidates)\n",
" return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))\n",
"\n",
"def find_best_doc(model, query, candidates):\n",
" \"\"\"\n",
" Finds the most relevant document from a list of candidates.\n",
"\n",
" Args:\n",
" model: The SentenceTransformer model.\n",
" query: The user's query string.\n",
" candidates: A list of document dictionaries, each with 'title' and 'content'.\n",
"\n",
" Returns:\n",
" A tuple containing the index of the best document and its similarity score.\n",
" \"\"\"\n",
" if not candidates:\n",
" return None, 0.0\n",
"\n",
" # Encode the query for retrieval\n",
" query_embedding = model.encode(query, prompt_name=\"Retrieval-query\")\n",
"\n",
" # Encode the document for similarity check\n",
" doc_texts = [\n",
" f\"title: {doc.get('title', 'none')} | text: {doc.get('content', '')}\"\n",
" for doc in candidates\n",
" ]\n",
" candidate_embeddings = model.encode(doc_texts)\n",
"\n",
" print([doc['title'] for doc in candidates])\n",
"\n",
" # Calculate cosine similarity\n",
" return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))\n",
"\n",
"# --- Main Search Logic ---\n",
"\n",
"# In your application, `best_document` would result from a search.\n",
"# We initialize it to None to ensure it always exists.\n",
"best_document = None\n",
"\n",
"# 1. Find the most relevant category\n",
"print(\"Step 1: Finding the best category...\")\n",
"categories = [item[\"category\"] for item in corp_knowledge_base]\n",
"best_category_index, category_score = find_best_category(\n",
" model, question, categories\n",
")\n",
"\n",
"# Check if the category score meets the threshold\n",
"if category_score < similarity_threshold:\n",
" print(f\" `-> 🤷 No relevant category found. The highest score was only {category_score:.2f}.\")\n",
"else:\n",
" best_category = corp_knowledge_base[best_category_index]\n",
" print(f\" `-> ✅ Category Found: '{best_category['category']}' (Score: {category_score:.2f})\")\n",
"\n",
" # 2. Find the most relevant document ONLY if a good category was found\n",
" print(\"\\nStep 2: Finding the best document in that category...\")\n",
" best_document_index, document_score = find_best_doc(\n",
" model, question, best_category[\"documents\"]\n",
" )\n",
"\n",
" # Check if the document score meets the threshold\n",
" if document_score < similarity_threshold:\n",
" print(f\" `-> 🤷 No relevant document found. The highest score was only {document_score:.2f}.\")\n",
" else:\n",
" best_document = best_category[\"documents\"][best_document_index]\n",
" # 3. Display the final successful result\n",
" print(f\" `-> ✅ Document Found: '{best_document['title']}' (Score: {document_score:.2f})\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zK9T5rRGAMDw"
},
"source": [
"Next, generate the answer with the retrieved context"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "FrwKySpMASpt"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question🙋‍♂️: How do I reset my password?\n",
"Using document: Account Password Management\n",
"Answer🤖: Please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions.\n"
]
}
],
"source": [
"from transformers import GenerationConfig\n",
"MODEL_ID = \"google/gemma-4-E2B-it\"\n",
"config = GenerationConfig.from_pretrained(MODEL_ID)\n",
"config.max_new_tokens = 512\n",
"\n",
"qa_prompt_template = \"\"\"Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write \"I don't know.\"\n",
"\n",
"---\n",
"CONTEXT:\n",
"{context}\n",
"---\n",
"QUESTION:\n",
"{question}\n",
"\"\"\"\n",
"\n",
"# First, check if a valid document was found before proceeding.\n",
"if best_document and \"content\" in best_document:\n",
" # If the document exists and has a \"content\" key, generate the answer.\n",
" context = best_document[\"content\"]\n",
"\n",
" prompt = qa_prompt_template.format(context=context, question=question)\n",
"\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [{\"type\": \"text\", \"text\": prompt}],\n",
" },\n",
" ]\n",
"\n",
" print(\"Question🙋‍♂️: \" + question)\n",
" # This part assumes your pipeline and response parsing logic are correct\n",
" answer = pipeline(messages, generation_config=config)[0][\"generated_text\"][1][\"content\"]\n",
" print(\"Using document: \" + best_document[\"title\"])\n",
" print(\"Answer🤖: \" + answer)\n",
"\n",
"else:\n",
" # If best_document is None or doesn't have content, give a direct response.\n",
" print(\"Question🙋‍♂️: \" + question)\n",
" print(\"Answer🤖: I'm sorry, I could not find a relevant document to answer that question.\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h4J4pFA3IK1d"
},
"source": [
"## Summary and next steps\n",
"\n",
"You have now learned how to build a practical RAG system with EmbeddingGemma.\n",
"\n",
"Explore what more you can do with EmbeddingGemma:\n",
"\n",
"* [Generate embeddings with Sentence Transformers](https://ai.google.dev/gemma/docs/embeddinggemma/inference-embeddinggemma-with-sentence-transformers)\n",
"* [Fine-tune EmbeddingGemma](https://ai.google.dev/gemma/docs/embeddinggemma/fine-tuning-embeddinggemma-with-sentence-transformers)\n",
"* [Mood Palette Generator](https://huggingface.co/spaces/google/mood-palette), an interactive application using EmbeddingGemma"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "RAG_with_EmbeddingGemma.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -0,0 +1,99 @@
# Gemma
[![Unittests](https://github.com/google-deepmind/gemma/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-deepmind/gemma/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/gemma.svg)](https://badge.fury.io/py/gemma)
[![Documentation Status](https://readthedocs.org/projects/gemma-llm/badge/?version=latest)](https://gemma-llm.readthedocs.io/en/latest/?badge=latest)
[Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language
Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini
research and technology.
This repository contains the implementation of the
[`gemma`](https://pypi.org/project/gemma/) PyPI package. A
[JAX](https://github.com/jax-ml/jax) library to use and fine-tune Gemma.
For examples and use cases, see our
[documentation](https://gemma-llm.readthedocs.io/). Please
report issues and feedback in
[our GitHub](https://github.com/google-deepmind/gemma/issues).
### Installation
1. Install JAX for CPU, GPU or TPU. Follow the instructions on
[the JAX website](https://jax.readthedocs.io/en/latest/installation.html).
1. Run
```sh
pip install gemma
```
### Examples
Here is a minimal example to have a multi-turn, multi-modal conversation with
Gemma:
```python
from gemma import gm
# Model and parameters (Gemma 4)
model = gm.nn.Gemma4_E4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
# Example of multi-turn conversation
sampler = gm.text.ChatSampler(
model=model,
params=params,
multi_turn=True,
)
prompt = """Which of the 2 images do you prefer ?
Image 1: <|image|>
Image 2: <|image|>
Write your answer as a poem."""
out0 = sampler.chat(prompt, images=[image1, image2])
out1 = sampler.chat('What about the other image ?')
```
The same `ChatSampler` API works with all Gemma versions (2, 3, 3n, 4).
Our documentation contains various Colabs and tutorials, including:
* [Sampling](https://gemma-llm.readthedocs.io/en/latest/colab_sampling.html)
* [Multi-modal](https://gemma-llm.readthedocs.io/en/latest/colab_multimodal.html)
* [Fine-tuning](https://gemma-llm.readthedocs.io/en/latest/colab_finetuning.html)
* [LoRA](https://gemma-llm.readthedocs.io/en/latest/colab_lora_sampling.html)
* ...
Additionally, our
[examples/](https://github.com/google-deepmind/gemma/tree/main/examples) folder
contain additional scripts to fine-tune and sample with Gemma.
### Learn more about Gemma
* To use this library: [Gemma documentation](https://gemma-llm.readthedocs.io/)
* Technical reports for metrics and model capabilities:
* [Gemma 1](https://goo.gle/GemmaReport)
* [Gemma 2](https://goo.gle/gemma2report)
* [Gemma 3](https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf)
* Gemma 4 (Coming soon)
* Other Gemma implementations and doc on the
[Gemma ecosystem](https://ai.google.dev/gemma/docs)
### Downloading the models
To download the model weights. See
[our documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).
### System Requirements
Gemma can run on a CPU, GPU and TPU. For GPU, we recommend 8GB+ RAM on GPU for
The 2B checkpoint and 24GB+ RAM on GPU are used for the 7B checkpoint.
### Contributing
We welcome contributions! Please read our [Contributing Guidelines](./CONTRIBUTING.md) before submitting a pull request.
*This is not an official Google product.*
File diff suppressed because one or more lines are too long
@@ -0,0 +1,568 @@
{
"cells": [
{
"metadata": {
"id": "-KkvqLgjiIdD"
},
"cell_type": "markdown",
"source": [
"# Tool Use\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/tool_use.ipynb)\n",
"\n",
"Demo to show how to use tool-use with Gemma library.\n",
"\n",
"Note: The Gemma 1, 2 and 3 models were not specifically trained for tool use. This is more a proof-of-concept than an officially supported feature."
]
},
{
"metadata": {
"id": "gcNRfVEnj4aq"
},
"cell_type": "code",
"source": [
"!pip install -q gemma"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {
"executionInfo": {
"elapsed": 2221,
"status": "ok",
"timestamp": 1749202985345,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "k1ZAgLg1j9NT"
},
"cell_type": "code",
"source": [
"# Common imports\n",
"import os\n",
"import datetime\n",
"\n",
"# Gemma imports\n",
"from gemma import gm"
],
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"id": "139lZszJj_CC"
},
"cell_type": "markdown",
"source": [
"By default, Jax does not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
]
},
{
"metadata": {
"executionInfo": {
"elapsed": 2,
"status": "ok",
"timestamp": 1749138071985,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "VtlWWLIYj_LJ"
},
"cell_type": "code",
"source": [
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
],
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"id": "31JPZb5RkD_p"
},
"cell_type": "markdown",
"source": [
"Load the model and the params."
]
},
{
"metadata": {
"executionInfo": {
"elapsed": 39057,
"status": "ok",
"timestamp": 1749203024713,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "RsAo6k4_kEJS",
"outputId": "e10afb5c-6c81-42e8-e590-a39ea4ef3bf7"
},
"cell_type": "code",
"source": [
"model = gm.nn.Gemma3_4B()\n",
"\n",
"params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)"
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:2025-06-06 02:43:16,896:jax._src.xla_bridge:749: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'\n",
"INFO:2025-06-06 02:43:16,897:jax._src.xla_bridge:749: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got \n",
"INFO:2025-06-06 02:43:16,900:jax._src.xla_bridge:749: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'\n",
"INFO:2025-06-06 02:43:16,901:jax._src.xla_bridge:749: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"id": "p108c5yIlYH7"
},
"cell_type": "markdown",
"source": [
"## Using existing tools\n",
"\n",
"If you're familiar with the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial, using tool-use differ in two ways:\n",
"\n",
"1. Using the `gm.text.ToolSampler` rather than the `gm.text.ChatSampler`.\n",
"2. Passing the `tools=` you want to use to the sampler.\n",
"\n",
"For example:"
]
},
{
"metadata": {
"colab": {
"height": 594
},
"executionInfo": {
"elapsed": 50615,
"status": "ok",
"timestamp": 1749138791069,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "iRCV5h8BlVX6",
"outputId": "b3b5d83d-8a8b-4982-fc8f-d409fb8b38a9"
},
"cell_type": "code",
"source": [
"sampler = gm.text.ToolSampler(\n",
" model=model,\n",
" params=params,\n",
" tools=[\n",
" gm.tools.Calculator(),\n",
" gm.tools.FileExplorer(),\n",
" ],\n",
" print_stream=True,\n",
")\n",
"\n",
"output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.\n",
"Let's start with S0 = 3.\n",
"S1 = cos(S0) * 2 = cos(3) * 2\n",
"S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2\n",
"S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2\n",
"S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2\n",
"\n",
"I will use the calculator to compute these values.\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(3) * 2\"}\n",
"\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: -1.9799849932008908]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"-1.9799849932008908 * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: -3.9599699864017817]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-3.9599699864017817) * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: -1.3668134299076982]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-1.3668134299076982) * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: 0.4051424976130353]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(0.4051424976130353) * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: 1.8380924822033438]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438"
]
}
],
"execution_count": 10
},
{
"metadata": {
"id": "FAI54F-Blkan"
},
"cell_type": "markdown",
"source": [
"Note: Only the final model answer is returned. You can access the conversation history, including all intermediates tool calls and output through `sampler.turns` property."
]
},
{
"metadata": {
"id": "D0_IIS1Nlfuw"
},
"cell_type": "markdown",
"source": [
"## Creating your own tool\n",
"\n",
"To create your own tool, you can inherit from the `gm.tools.Tool` class. You should provide:\n",
"\n",
"* A description & example, so the model knows how to use your tool\n",
"* Implement the `call` method. The `call` function can take arbitrary `**kwargs`, but the name of the args should match the ones defined in `tool_kwargs` and `tool_kwargs_doc`"
]
},
{
"metadata": {
"executionInfo": {
"elapsed": 55,
"status": "ok",
"timestamp": 1749203934196,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "XqmQcfdI0oEl"
},
"cell_type": "code",
"source": [
"class DateTime(gm.tools.Tool):\n",
" \"\"\"Tool to access the current date.\"\"\"\n",
"\n",
" DESCRIPTION = 'Access the current date, time,...'\n",
" EXAMPLE = gm.tools.Example(\n",
" query='Which day of the week are we today ?',\n",
" thought='The `datetime.strptime` uses %a for day of the week',\n",
" tool_kwargs={'format': '%a'},\n",
" tool_kwargs_doc={'format': '<ANY datetime.strptime expression>'},\n",
" result='Sat',\n",
" answer='Today is Saturday.',\n",
" )\n",
"\n",
" def call(self, format: str) -> str:\n",
" dt = datetime.datetime.now()\n",
" return dt.strftime(format)\n"
],
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"id": "sSxYhXPuuXYp"
},
"cell_type": "markdown",
"source": [
"The tool can then be used in the sampler:"
]
},
{
"metadata": {
"colab": {
"height": 118
},
"executionInfo": {
"elapsed": 2156,
"status": "ok",
"timestamp": 1749204833094,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "9S8xB2B-0cbW",
"outputId": "fccc0e89-e922-4184-8b77-800041cdd77e"
},
"cell_type": "code",
"source": [
"sampler = gm.text.ToolSampler(\n",
" model=model,\n",
" params=params,\n",
" tools=[\n",
" DateTime(),\n",
" ],\n",
" print_stream=True,\n",
")\n",
"\n",
"output = sampler.chat('Which date are we today ?')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: I need to get the current date.\n",
"{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: 2025-06-06]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Today is June 6th, 2025."
]
}
],
"execution_count": 9
},
{
"metadata": {
"id": "esIpCjhxzHmf"
},
"cell_type": "markdown",
"source": [
"## Next steps\n",
"\n",
"* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n",
"* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n"
]
}
],
"metadata": {
"colab": {
"last_runtime": {},
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -0,0 +1,130 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example config for finetuning Gemma for a classification task.
* Input: A text to classify.
* Output: A classification label. The pre-trained Gemma model is trained to
predict one world among 256.000. Here, we're finetuning to predict only 2
tokens among the 256.000 available.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/classification.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
"""Get the default hyperparameter configuration."""
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(training=True),
# Model definition
model=gm.nn.Gemma3_4B(
tokens="batch.sentence",
return_last_only=True,
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.label",
),
},
optimizer=optax.adafactor(learning_rate=1e-4),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(training=False),
),
},
)
def _make_dataset(training: bool) -> kd.data.Pipeline:
# Dict key names from the dataset
_INPUT_FIELD = "sentence" # pylint: disable=invalid-name
_LABEL_FIELD = "label" # pylint: disable=invalid-name
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="glue/cola",
split="train" if training else "validation",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=8,
transforms=[
# Process the input text
# TFDS datasets returns `bytes`, so convert them to `str`
gm.data.DecodeBytes(key=_INPUT_FIELD),
gm.data.FormatText(
key=_INPUT_FIELD,
template="""<start_of_turn>user
Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
Sentence: {text}<end_of_turn>
<start_of_turn>model""",
),
gm.data.Tokenize(
key=_INPUT_FIELD,
tokenizer=tokenizer,
add_bos=True,
),
gm.data.Pad(
key=_INPUT_FIELD,
max_length=128,
),
# Process the label
gm.data.MapInts(
key=_LABEL_FIELD,
# Rather than predicting the token 0 and 1, we are using the
# token 1294 and 3553 which respectivelly correspond to "No" and
# "Yes". We do this because those token already contain semantic
# information, so even zero-shot prediction without any
# finetuning has better than random performances.
old_to_new={
0: 1294, # Token -> "No"
1: 3553, # Token -> "Yes"
},
),
kd.data.Rearrange(
key=_LABEL_FIELD,
pattern="... -> ... 1", # For shape compatibility with the loss.
),
],
)
@@ -0,0 +1,122 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""DPO Example.
DPO works by running two answers (one prefered and one rejected) into both
the reference model and the model to finetune. Then the DPO loss is used to
increase the likelihood of generating the preferred answer.
Implementation wise, this is done by:
* Wrapping the model inside a `gm.nn.AnchoredPolicy` (which runs both the
model and the reference frozen model)
* Using the `gm.ckpts.AnchoredPolicyLoader` to restore the weights, so the
weights are correctly mapped to inside `gm.nn.AnchoredPolicy`.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/dpo.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
"""Get the default hyperparameter configuration."""
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(training=True),
# Model definition
model=gm.nn.AnchoredPolicy(
policy=gm.nn.Gemma3_4B(tokens="batch.tokens", text_only=True),
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.AnchoredPolicyLoader(
policy=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
),
# Training
num_train_steps=10_000,
train_losses={
"dpo": gm.losses.DpoLoss(
tokens="batch.targets",
sequence_mask="batch.mask",
policy_logits="preds.policy.logits",
anchor_logits="preds.anchor.logits",
),
},
optimizer=optax.adafactor(learning_rate=1e-4),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
# "test": kd.evals.Evaluator(
# run=kd.evals.EveryNSteps(1000),
# ds=_make_dataset(training=False),
# ),
},
)
def _make_dataset(training: bool) -> kd.data.Pipeline:
# TODO(epot): !!!!
max_length = 512
batch_size = 16
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.HuggingFace(
path="argilla/distilabel-math-preference-dpo",
split="train",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=batch_size,
transforms=[
# Only keep the fields we need.
kd.data.Elements(
keep=["instruction", "chosen_response", "rejected_response"]
),
# Create the model inputs and loss mask.
gm.data.ContrastiveTask(
in_prompt="instruction",
in_chosen="chosen_response",
in_rejected="rejected_response",
out_tokens="tokens",
out_targets="targets",
out_mask="mask",
tokenizer=tokenizer,
# Padding parameters
max_length=max_length,
# TODO(epot): Run stats (how many examples are we dropping?)
truncate=True,
),
],
)
@@ -0,0 +1,154 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example of Gemma finetuning using LoRA.
This example is based on the `seq2seq.py` example. See the
docstring of that file for more details.
The changes to use LoRA are:
* `model`: Use `gm.nn.LoRA()` wrapper to add `LoRA` adapters to the
model.
* `init_transform`: Use `gm.ckpts.SkipLoRA()` wrapper to only restore the
non-LoRA weights.
* `optimizer`: Use `kd.optim.partial_updates` wrapper to only train the LoRA
weights.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/lora.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
batch_size = 16
max_length = 512
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(
training=True,
batch_size=batch_size,
max_length=max_length,
),
# Model definition
model=gm.nn.LoRA(
rank=4,
model=gm.nn.Gemma3_4B(
tokens="batch.input",
# TODO(epot): At the moment, LoRA fine-tuning with multimodal
# is not supported. Willbe fixed soon.
text_only=True,
),
),
# Load the weights from the pretrained checkpoint
# Use `SkipLoRA` as the original checkpoint does not contain the LoRA
# weights.
init_transform=gm.ckpts.SkipLoRA(
wrapped=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
# TODO(epot): Add Gradient accumenlation.
optimizer=kd.optim.partial_updates(
optax.adafactor(learning_rate=0.005),
# We only optimize the LoRA weights. The rest of the model is frozen.
mask=kd.optim.select("lora"),
),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(
training=False,
batch_size=batch_size,
max_length=max_length,
),
),
# The sampler evaluator run inference on a few prompts from the
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
max_new_tokens=150, # Sampling parameters
num_batches=1, # Only predict a single example (batch_size=None)
ds=_make_dataset(training=False, sampling=True),
),
},
)
def _make_dataset(
*,
training: bool,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="mtnt/en-fr",
split="train" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://www.tensorflow.org/datasets/catalog/mtnt
in_prompt="src",
in_response="dst",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
sampling=sampling,
),
],
)
@@ -0,0 +1,164 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example of Gemma finetuning for an image captioning task.
Example:
Prompt:
```
<start_of_turn>user
<start_of_image><end_of_turn>
<start_of_turn>model
```
Target:
```
A diagram showing a circuit with a battery, lamp, and switch.<end_of_turn>
```
Here, the prompt only contains the `<start_of_image>` to indicate an image
is inserted.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/multimodal.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
import jax.numpy as jnp
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
batch_size = 32
max_length = 200
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(
training=True,
batch_size=batch_size,
max_length=max_length,
),
# Model definition
model=gm.nn.Gemma3_4B(
tokens="batch.input",
images="batch.image",
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
train_summaries={
"image": kd.summaries.ShowImages(images="batch.image", num_images=5),
},
optimizer=optax.adafactor(learning_rate=1e-3),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(
training=False,
batch_size=4,
max_length=max_length,
),
),
# The sampler evaluator run inference on a few prompts from the
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
max_new_tokens=50, # Sampling parameters
num_batches=3,
ds=_make_dataset(training=False, sampling=True),
summaries={
"image": kd.summaries.ShowImages(
images="batch.image", num_images=5
),
},
),
},
)
def _make_dataset(
*,
training: bool,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="ai2dcaption",
split="llava_15" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
# Only keep the fields we need.See fields at:
# https://www.tensorflow.org/datasets/catalog/ai2dcaption
kd.data.Elements(keep=["image", "caption"]),
# Create a new constant field
kd.data.AddConstants({"prompt": "<start_of_image>"}),
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
in_prompt="prompt",
in_response="caption",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
sampling=sampling,
),
kd.data.py.Resize(key="image", size=(800, 800)),
# TODO(epot): Make the `num_images` dimension optional
kd.data.Rearrange(key="image", pattern="... h w c -> ... 1 h w c"),
kd.data.Cast(key="image", dtype=jnp.uint8),
],
)
@@ -0,0 +1,133 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example of Gemma finetuning for a prompt -> response task.
This is a fork of the seq2seq example, but with sharding.
The only difference is the `sharding=kd.sharding.ShardingStrategy()`
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/sharding.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
batch_size = 16
max_length = 512
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(
training=True,
batch_size=batch_size,
max_length=max_length,
),
# Model definition
model=gm.nn.Gemma3_4B(
tokens="batch.input",
),
sharding=kd.sharding.ShardingStrategy(
params=kd.sharding.FSDPSharding(),
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
optimizer=optax.adafactor(learning_rate=1e-3),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(
training=False,
batch_size=batch_size,
max_length=max_length,
),
),
# The sampler evaluator run inference on a few prompts from the
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
max_new_tokens=50, # Sampling parameters
num_batches=1, # Only predict a single example (batch_size=None)
ds=_make_dataset(training=False, sampling=True),
),
},
)
def _make_dataset(
*,
training: bool,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="mtnt/en-fr",
split="train" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://www.tensorflow.org/datasets/catalog/mtnt
in_prompt="src",
in_response="dst",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
sampling=sampling,
),
],
)
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,250 @@
# Gemma.cpp API Server
This is an HTTP API server for gemma.cpp that implements the Google API protocol, allowing you to interact with Gemma models through REST API endpoints compatible with the Google API format.
## Features
- **API-compatible**: Implements Google API endpoints
- **Unified client/server**: Single codebase supports both local and public API modes
- **Text generation**: Support for `generateContent` endpoint
- **Streaming support**: Server-Sent Events (SSE) for `streamGenerateContent`
- **Model management**: Support for `/v1beta/models` endpoint
- **Session management**: Maintains conversation context with KV cache
- **JSON responses**: All responses in Google API format
- **Error handling**: Proper HTTP status codes and error messages
## Building
The API server is built alongside the main gemma.cpp project:
```bash
# Configure the build
cmake -B build -DCMAKE_BUILD_TYPE=Release
# Build the API server and client
cmake --build build --target gemma_api_server gemma_api_client -j 8
```
The binaries will be created at:
- `build/gemma_api_server` - Local API server
- `build/gemma_api_client` - Unified client for both local and public APIs
## Usage
### Starting the Local API Server
```bash
./build/gemma_api_server \
--tokenizer path/to/tokenizer.spm \
--weights path/to/model.sbs \
--port 8080
```
**Required arguments:**
- `--tokenizer`: Path to the tokenizer file (`.spm`)
- `--weights`: Path to the model weights file (`.sbs`)
**Optional arguments:**
- `--port`: Port to listen on (default: 8080)
- `--model`: Model name for API endpoints (default: gemma3-4b)
### Using the Unified Client
#### With Local Server
```bash
# Interactive chat with local server
./build/gemma_api_client --interactive 1 --host localhost --port 8080
# Single prompt with local server
./build/gemma_api_client --prompt "Hello, how are you?"
```
#### With Public Google API
```bash
# Set API key and use public API
export GOOGLE_API_KEY="your-api-key-here"
./build/gemma_api_client --interactive 1
# Or pass API key directly
./build/gemma_api_client --api_key "your-api-key" --interactive 1
```
## API Endpoints
The server implements Google API endpoints:
### 1. Generate Content - `POST /v1beta/models/gemma3-4b:generateContent`
Generate a response for given content (non-streaming).
**Request:**
```json
{
"contents": [
{
"parts": [
{"text": "Why is the sky blue?"}
]
}
],
"generationConfig": {
"temperature": 0.9,
"topK": 1,
"maxOutputTokens": 1024
}
}
```
**Response:**
```json
{
"candidates": [
{
"content": {
"parts": [
{"text": "The sky appears blue because..."}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
],
"promptFeedback": {
"safetyRatings": []
},
"usageMetadata": {
"promptTokenCount": 5,
"candidatesTokenCount": 25,
"totalTokenCount": 30
}
}
```
### 2. Stream Generate Content - `POST /v1beta/models/gemma3-4b:streamGenerateContent`
Generate a response with Server-Sent Events (SSE) streaming.
**Request:** Same as above
**Response:** Stream of SSE events:
```
data: {"candidates":[{"content":{"parts":[{"text":"The"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}
data: {"candidates":[{"content":{"parts":[{"text":" sky"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}
data: [DONE]
```
### 3. List Models - `GET /v1beta/models`
List available models.
**Response:**
```json
{
"models": [
{
"name": "models/gemma3-4b",
"displayName": "Gemma3 4B",
"description": "Gemma3 4B model running locally"
}
]
}
```
## Example Usage
### Using curl with Local Server
```bash
# Generate content (non-streaming)
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [{"parts": [{"text": "Hello, how are you?"}]}],
"generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024}
}'
# Stream generate content (SSE)
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:streamGenerateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [{"parts": [{"text": "Tell me a story"}]}],
"generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024}
}'
# List models
curl http://localhost:8080/v1beta/models
```
### Multi-turn Conversation with curl
```bash
# First message
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [
{"parts": [{"text": "Hi, my name is Alice"}]}
]
}'
# Follow-up message with conversation history
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [
{"parts": [{"text": "Hi, my name is Alice"}]},
{"parts": [{"text": "Hello Alice! Nice to meet you."}]},
{"parts": [{"text": "What is my name?"}]}
]
}'
```
### Using Python
```python
import requests
# Generate content
response = requests.post('http://localhost:8080/v1beta/models/gemma3-4b:generateContent',
json={
'contents': [{'parts': [{'text': 'Explain quantum computing in simple terms'}]}],
'generationConfig': {
'temperature': 0.9,
'topK': 1,
'maxOutputTokens': 1024
}
}
)
result = response.json()
if 'candidates' in result and result['candidates']:
text = result['candidates'][0]['content']['parts'][0]['text']
print(text)
```
## Configuration Options
The Google API supports various generation configuration options:
- **temperature**: Controls randomness (0.0 to 2.0, default: 1.0)
- **topK**: Top-K sampling parameter (default: 1)
- **maxOutputTokens**: Maximum number of tokens to generate (default: 8192)
## Key Features
- **Unified Implementation**: Same codebase handles both local server and public API
- **Session Management**: Maintains conversation context using KV cache
- **Streaming Support**: Real-time token generation via Server-Sent Events
- **Error Handling**: Comprehensive error responses and HTTP status codes
- **Memory Efficient**: Optimized token processing and caching
## Compatibility
This implementation is compatible with:
- Google API format and endpoints
- Standard HTTP clients (curl, browsers, Python requests, etc.)
- Server-Sent Events (SSE) for streaming responses
- JSON request/response format
+532
View File
@@ -0,0 +1,532 @@
# gemma.cpp
gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
foundation models from Google.
For additional information about Gemma, see
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
gemma.cpp specific artifacts, are
[available on kaggle](https://www.kaggle.com/models/google/gemma-2).
## Who is this project for?
Modern LLM inference engines are sophisticated systems, often with bespoke
capabilities extending beyond traditional neural network runtimes. With this
comes opportunities for research and innovation through co-design of high level
algorithms and low-level computation. However, there is a gap between
deployment-oriented C++ inference runtimes, which are not designed for
experimentation, and Python-centric ML research frameworks, which abstract away
low-level computation through compilation.
gemma.cpp provides a minimalist implementation of Gemma-2, Gemma-3, and
PaliGemma-2 models, focusing on simplicity and directness rather than full
generality. This is inspired by vertically-integrated model implementations such
as [ggml](https://github.com/ggerganov/ggml),
[llama.c](https://github.com/karpathy/llama2.c), and
[llama.rs](https://github.com/srush/llama2.rs).
gemma.cpp targets experimentation and research use cases. It is intended to be
straightforward to embed in other projects with minimal dependencies and also
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
of supporting utilities). We use the [Google
Highway](https://github.com/google/highway) Library to take advantage of
portable SIMD for CPU inference.
For production-oriented edge deployments we recommend standard deployment
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
([all model variations here](https://www.kaggle.com/models/google/gemma)).
## Contributing
Community contributions large and small are welcome. See
[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md)
for additional notes contributing developers and [join the discord by following
this invite link](https://discord.gg/H5jCBAWxAe). This project follows
[Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).
> [!NOTE] Active development is currently done on the `dev` branch. Please open
> pull requests targeting `dev` branch instead of `main`, which is intended to
> be more stable.
## What's inside?
- LLM
- CPU-only inference for: Gemma 2-3, PaliGemma 2.
- Sampling with TopK and temperature.
- Backward pass (VJP) and Adam optimizer for Gemma research.
- Optimizations
- Mixed-precision (fp8, bf16, fp32, fp64 bit) GEMM:
- Designed for BF16 instructions, can efficiently emulate them.
- Automatic runtime autotuning 7 parameters per matrix shape.
- Weight compression integrated directly into GEMM:
- Custom fp8 format with 2..3 mantissa bits; tensor scaling.
- Also bf16, f32 and non-uniform 4-bit (NUQ); easy to add new formats.
- Infrastructure
- SIMD: single implementation via Highway. Chooses ISA at runtime.
- Tensor parallelism: CCX-aware, multi-socket thread pool.
- Disk I/O: memory map or parallel read (heuristic with user override).
- Custom format with forward/backward-compatible metadata serialization.
- Model conversion from Safetensors, not yet open sourced.
- Portability: Linux, Windows/OS X supported. CMake/Bazel. 'Any' CPU.
- Frontends
- C++ APIs with streaming for single query and batched inference.
- Basic interactive command-line app.
- Basic Python bindings (pybind11).
## Quick Start
### System requirements
Before starting, you should have installed:
- [CMake](https://cmake.org/)
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
least C++17.
- `tar` for extracting archives from Kaggle.
Building natively on Windows requires the Visual Studio 2012 Build Tools with the
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
command line with
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
```sh
winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
Visit the
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
and select `Model Variations |> Gemma C++`.
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
weights are higher fidelity, while 8-bit switched floating point weights enable
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
> [!NOTE] **Important**: We strongly recommend starting off with the
> `gemma2-2b-it-sfp` model to get up and running.
Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
`ModelPrefix` function in `configs.cc`.
### Step 2: Extract Files
After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):
```
tar -xf archive.tar.gz
```
This should produce a file containing model weights such as `2b-it-sfp.sbs` and
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
convenient directory location (e.g. the `build/` directory in this repo).
### Step 3: Build
The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory. Note if you previous ran `cmake` and are
re-running with a different setting, be sure to delete all files in the `build/`
directory with `rm -rf build/*`.
#### Unix-like Platforms
```sh
cmake -B build
```
After running `cmake`, you can enter the `build/` directory and run `make` to
build the `./gemma` executable:
```sh
# Configure `build` directory
cmake --preset make
# Build project using make
cmake --build --preset make -j [number of parallel threads to use]
```
Replace `[number of parallel threads to use]` with a number - the number of
cores available on your system is a reasonable heuristic. For example, `make -j4
gemma` will build using 4 threads. If the `nproc` command is available, you can
use `make -j$(nproc) gemma` as a reasonable default for the number of threads.
If you aren't sure of the right value for the `-j` flag, you can simply run
`make gemma` instead and it should still build the `./gemma` executable.
> [!NOTE]
> On Windows Subsystem for Linux (WSL) users should set the number of
> parallel threads to 1. Using a larger number may result in errors.
If the build is successful, you should now have a `gemma` executable in the
`build/` directory.
#### Windows
```sh
# Configure `build` directory
cmake --preset windows
# Build project using Visual Studio Build Tools
cmake --build --preset windows -j [number of parallel threads to use]
```
If the build is successful, you should now have a `gemma.exe` executable in the
`build/` directory.
#### Bazel
```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```
If the build is successful, you should now have a `gemma` executable in the
`bazel-bin/` directory.
#### Make
If you prefer Makefiles, @jart has made one available here:
https://github.com/jart/gemma3/blob/main/Makefile
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.
`gemma` has the following required arguments:
Argument | Description | Example value
------------- | ---------------------------- | ---------------
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
Example invocation for the following configuration:
- weights file `gemma2-2b-it-sfp.sbs` (Gemma2 2B instruction-tuned model,
8-bit switched floating point).
- Tokenizer file `tokenizer.spm` (can omit for single-format weights files
created after 2025-05-06, or output by migrate_weights.cc).
```sh
./gemma \
--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
```
### PaliGemma Vision-Language Model
This repository includes a version of the PaliGemma 2 VLM
([paper](https://arxiv.org/abs/2412.03555)). We provide a C++ implementation of
the PaliGemma 2 model here.
To use the version of PaliGemma included in this repository, build the gemma
binary as noted above in Step 3. Download the compressed weights and tokenizer
from
[Kaggle](https://www.kaggle.com/models/google/paligemma-2/gemmaCpp/paligemma2-3b-mix-224)
and run the binary as follows:
```sh
./gemma \
--tokenizer paligemma_tokenizer.model \
--weights paligemma2-3b-mix-224-sfp.sbs \
--image_file paligemma/testdata/image.ppm
```
Note that the image reading code is very basic to avoid depending on an image
processing library for now. We currently only support reading binary PPMs (P6).
So use a tool like `convert` to first convert your images into that format, e.g.
`convert image.jpeg -resize 224x224^ image.ppm`
(As the image will be resized for processing anyway, we can already resize at
this stage for slightly faster loading.)
The interaction with the image (using the mix-224 checkpoint) may then look
something like this:
```
> Describe the image briefly
A large building with two towers in the middle of a city.
> What type of building is it?
church
> What color is the church?
gray
> caption image
A large building with two towers stands tall on the water's edge. The building
has a brown roof and a window on the side. A tree stands in front of the
building, and a flag waves proudly from its top. The water is calm and blue,
reflecting the sky above. A bridge crosses the water, and a red and white boat
rests on its surface. The building has a window on the side, and a flag on top.
A tall tree stands in front of the building, and a window on the building is
visible from the water. The water is green, and the sky is blue.
```
### Migrating to single-file format
There is now a new format for the weights file, which is a single file that
allows to contain the tokenizer (and the model type) directly. A tool to migrate
from the multi-file format to the single-file format is available.
```sh
io/migrate_weights \
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
--output_weights .../gemma2-2b-it-sfp-single.sbs
```
After migration, you can omit the tokenizer argument like this:
```sh
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
```
### Troubleshooting and FAQs
**Problems building in Windows / Visual Studio**
Currently if you're using Windows, we recommend building in WSL (Windows
Subsystem for Linux). We are exploring options to enable other build
configurations, see issues for active discussion.
**Model does not respond to instructions and produces strange output**
A common issue is that you are using a pre-trained model, which is not
instruction-tuned and thus does not respond to instructions. Make sure you are
using an instruction-tuned model (`gemma2-2b-it-sfp`) and not a pre-trained
model (any model with a `-pt` suffix).
**What sequence lengths are supported?**
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
models larger than 1B, this is typically 32K but 128K would also work given
enough RAM. Note that long sequences will be slow due to the quadratic cost of
attention.
**How do I convert my fine-tune to a `.sbs` compressed model file?**
For PaliGemma 2 checkpoints, you can use python/convert_from_safetensors.py to
convert from safetensors format (tested with building via bazel). For an adapter
model, you will likely need to call merge_and_unload() to convert the adapter
model to a single-file format before converting it.
Here is how to use it using a bazel build of the compression library assuming
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
```sh
bazel build //compression/python:compression
BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
python3 -c "import site; print(site.getsitepackages())"
# Use your sites-packages file here:
ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
```
**What are some easy ways to make the model run faster?**
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
These are half the size of bf16 and thus use less memory bandwidth and cache
space.
2. Due to auto-tuning, the second and especially third query will be faster.
3. If you're on a laptop, make sure power mode is set to maximize performance
and saving mode is **off**. For most laptops, the power saving modes get
activated automatically if the computer is not plugged in.
4. Close other unused cpu-intensive applications.
5. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
cores get engaged.
We're also working on algorithmic and optimization approaches for faster
inference, stay tuned.
## Usage
`gemma` has different usage modes, controlled by the verbosity flag.
All usage modes are currently interactive, triggering text generation upon
newline input.
| Verbosity | Usage mode | Details |
| --------------- | ---------- | --------------------------------------------- |
| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. |
| `--verbosity 1` | Default | Standard user-facing terminal UI. |
| `--verbosity 2` | Detailed | Shows additional developer and debug info. |
### Interactive Terminal App
By default, verbosity is set to 1, bringing up a terminal-based interactive
interface when `gemma` is invoked:
```sh
$ ./gemma [...]
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
__/ | | | | |
|___/ |_| |_|
...
*Usage*
Enter an instruction and press enter (%C reset conversation, %Q quits).
*Examples*
- Write an email to grandma thanking her for the cookies.
- What are some historical attractions to visit around Massachusetts?
- Compute the nth fibonacci number in javascript.
- Write a standup comedy bit about WebGPU programming.
> What are some outdoorsy places to visit around Boston?
[ Reading prompt ] .....................
**Boston Harbor and Islands:**
* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.
**Forest and Nature:**
* **Forest Park:** Hike through a scenic forest with diverse wildlife.
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.
...
```
### Usage as a Command Line Tool
For using the `gemma` executable as a command line tool, it may be useful to
create an alias for gemma.cpp with arguments fully specified:
```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --verbosity 0"
```
Replace the above paths with your own paths to the model and tokenizer paths
from the download.
Here is an example of prompting `gemma` with a truncated input
file (using a `gemma2b` alias like defined above):
```sh
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
```
> [!NOTE]
> CLI usage of gemma.cpp is experimental and should take context length
> limitations into account.
The output of the above command should look like:
```sh
[ Reading prompt ] [...]
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
Let's break down the code:
[...]
```
### Incorporating gemma.cpp as a Library in your Project
The easiest way to incorporate gemma.cpp in your own project is to pull in
gemma.cpp and dependencies using `FetchContent`. You can add the following to
your CMakeLists.txt:
```
include(FetchContent)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway)
```
Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific
commit hash if you would like to pin the library version.
After your executable is defined (substitute your executable name for
`[Executable Name]` below):
```
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
FetchContent_GetProperties(gemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
```
### Building gemma.cpp as a Library
gemma.cpp can also be used as a library dependency in your own project. The
shared library artifact can be built by modifying the make invocation to build
the `libgemma` target instead of `gemma`.
> [!NOTE]
> If you are using gemma.cpp in your own project with the `FetchContent` steps
> in the previous section, building the library is done automatically by `cmake`
> and this section can be skipped.
First, run `cmake`:
```sh
cmake -B build
```
Then, run `make` with the `libgemma` target:
```sh
cd build
make -j [number of parallel threads to use] libgemma
```
If this is successful, you should now have a `libgemma` library file in the
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
## Independent Projects Using gemma.cpp
Some independent projects using gemma.cpp:
- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)
If you would like to have your project included, feel free to get in touch or
submit a PR with a `README.md` edit.
## Acknowledgements and Contacts
gemma.cpp was started in fall 2023 by
[Austin Huang](mailto:austinvhuang@google.com) and
[Jan Wassenberg](mailto:janwas@google.com), and subsequently released February
2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka. It was removed in 2025-09.
Gemma-2 support was implemented in June/July 2024 with the help of several
people.
PaliGemma support was implemented in September 2024 with contributions from
Daniel Keysers.
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
improvements, including major gains in efficiency, since the initial release.
This is not an officially supported Google product.
@@ -0,0 +1,7 @@
# Examples
In this directory are some simple examples illustrating usage of `gemma.cpp` as
a library beyond the interactive `gemma` app implemented in `run.cc`.
- `hello_world/` - minimal/template project for using `gemma.cpp` as a library.
It sets up the model state and generates text for a single hard coded prompt.
@@ -0,0 +1,186 @@
# Gemma in PyTorch
**Gemma** is a family of lightweight, state-of-the art open models built from research and technology used to create Google Gemini models. They include both text-only and multimodal decoder-only large language models, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:
* [Gemma on Google AI](https://ai.google.dev/gemma)
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma-3)
* [Gemma on Vertex AI Model Garden](https://pantheon.corp.google.com/vertex-ai/publishers/google/model-garden/gemma3)
This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.
## Updates
* [March 12th, 2025 🔥] Support Gemma v3. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-3/pytorch) and [Hugging Face](https://huggingface.co/models?other=gemma_torch)
* [June 26th, 2024] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
* [April 9th, 2024] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
* [April 5, 2024] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
## Download Gemma model checkpoint
You can find the model checkpoints on Kaggle:
- [Gemma 3](https://www.kaggle.com/models/google/gemma-3/pyTorch)
- [Gemma 2](https://www.kaggle.com/models/google/gemma-2/pyTorch)
- [Gemma](https://www.kaggle.com/models/google/gemma/pyTorch)
Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go the the model repository of the model of interest and click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub` installed, you can also run:
```
huggingface-cli download google/gemma-3-4b-it-pytorch
```
The following model sizes are available:
- **Gemma 3**:
- **Text only**: 1b
- **Multimodal**: 4b, 12b, 27b_v3
- **Gemma 2**:
- **Text only**: 2b-v2, 9b, 27b
- **Gemma**:
- **Text only**: 2b, 7b
Note that you can choose between the 1B, 4B, 12B, and 27B variants.
```
VARIANT=<1b, 2b, 2b-v2, 4b, 7b, 9b, 12b, 27b, 27b_v3>
CKPT_PATH=<Insert ckpt path here>
```
## Try it free on Colab
Follow the steps at
[https://ai.google.dev/gemma/docs/pytorch_gemma](https://ai.google.dev/gemma/docs/pytorch_gemma).
## Try it out with PyTorch
Prerequisite: make sure you have setup docker permission properly as a non-root user.
```bash
sudo usermod -aG docker $USER
newgrp docker
```
### Build the docker image.
```bash
DOCKER_URI=gemma:${USER}
docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}
```
### Run Gemma inference on CPU.
> NOTE: This is a multimodal example. Use a multimodal variant.
```bash
docker run -t --rm \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_multimodal.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Run Gemma inference on GPU.
> NOTE: This is a multimodal example. Use a multimodal variant.
```bash
docker run -t --rm \
--gpus all \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_multimodal.py \
--device=cuda \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}"
# add `--quant` for the int8 quantized model.
```
## Try It out with PyTorch/XLA
### Build the docker image (CPU, TPU).
```bash
DOCKER_URI=gemma_xla:${USER}
docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
```
### Build the docker image (GPU).
```bash
DOCKER_URI=gemma_xla_gpu:${USER}
docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}
```
### Run Gemma inference on CPU.
> NOTE: This is a multimodal example. Use a multimodal variant.
```bash
docker run -t --rm \
--shm-size 4gb \
-e PJRT_DEVICE=CPU \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_xla.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Run Gemma inference on TPU.
Note: be sure to use the docker container built from `xla.Dockerfile`.
```bash
docker run -t --rm \
--shm-size 4gb \
-e PJRT_DEVICE=TPU \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_xla.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Run Gemma inference on GPU.
Note: be sure to use the docker container built from `xla_gpu.Dockerfile`.
```bash
docker run -t --rm --privileged \
--shm-size=16g --net=host --gpus all \
-e USE_CUDA=1 \
-e PJRT_DEVICE=CUDA \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_xla.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Tokenizer Notes
99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. Unused tokens are in the string format of `<unused[0-97]>` with token id range of `[7-104]`.
```
"<unused0>": 7,
"<unused1>": 8,
"<unused2>": 9,
...
"<unused98>": 104,
```
## Disclaimer
This is not an officially supported Google product.
@@ -0,0 +1,107 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import random
import numpy as np
import torch
from absl import app, flags
from gemma import config
from gemma import model as gemma_model
# Define flags
FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True)
flags.DEFINE_string('variant', '4b', 'Model variant.')
flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.')
flags.DEFINE_integer('seed', 12345, 'Random seed.')
flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
flags.DEFINE_string('prompt', 'What are large language models?', 'Input prompt for the model.')
# Define valid text only model variants
_VALID_MODEL_VARIANTS = ['2b', '2b-v2', '7b', '9b', '27b', '1b']
# Define valid devices
_VALID_DEVICES = ['cpu', 'cuda']
# Validator function for the 'variant' flag
def validate_variant(variant):
if variant not in _VALID_MODEL_VARIANTS:
raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}')
return True
# Validator function for the 'device' flag
def validate_device(device):
if device not in _VALID_DEVICES:
raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}')
return True
# Register the validator for the 'variant' flag
flags.register_validator('variant', validate_variant, message='Invalid model variant.')
# Register the validator for the 'device' flag
flags.register_validator('device', validate_device, message='Invalid device.')
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def main(_):
# Construct the model config.
model_config = config.get_model_config(FLAGS.variant)
model_config.dtype = "float32"
model_config.quant = FLAGS.quant
# Seed random.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
torch.manual_seed(FLAGS.seed)
# Create the model and load the weights.
device = torch.device(FLAGS.device)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(FLAGS.ckpt)
model = model.to(device).eval()
print("Model loading done")
# Generate the response.
result = model.generate(FLAGS.prompt, device, output_len=FLAGS.output_len)
# Print the prompts and results.
print('======================================')
print(f'PROMPT: {FLAGS.prompt}')
print(f'RESULT: {result}')
print('======================================')
if __name__ == "__main__":
app.run(main)
# How to run this script:
# Example command (replace with your actual paths and values):
# python scripts/run.py --device=cpu --ckpt=/path/to/your/pytorch_checkpoint/model.ckpt --output_len=2 --prompt="The name of the capital of Italy is"
# Important:
# - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file.
# - Choose the correct --variant (model size).
# - Use --device=cuda if you have a GPU; otherwise, use --device=cpu.
@@ -0,0 +1,197 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import random
from absl import app
from absl import flags
import numpy as np
from PIL import Image
import torch
from gemma import config
from gemma import gemma3_model
# Define flags
FLAGS = flags.FLAGS
_CKPT = flags.DEFINE_string(
'ckpt', None, 'Path to the checkpoint file.', required=True
)
_VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.')
_DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
_OUTPUT_LEN = flags.DEFINE_integer(
'output_len', 10, 'Length of the output sequence.'
)
_SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.')
_QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
# Define valid multimodal model variants
_VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3']
# Define valid devices
_VALID_DEVICES = ['cpu', 'cuda']
# Validator function for the 'variant' flag
def validate_variant(variant):
if variant not in _VALID_MODEL_VARIANTS:
raise ValueError(
f'Invalid variant: {variant}. Valid variants are:'
f' {_VALID_MODEL_VARIANTS}'
)
return True
# Validator function for the 'device' flag
def validate_device(device):
if device not in _VALID_DEVICES:
raise ValueError(
f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}'
)
return True
# Register the validator for the 'variant' flag
flags.register_validator(
'variant', validate_variant, message='Invalid model variant.'
)
# Register the validator for the 'device' flag
flags.register_validator('device', validate_device, message='Invalid device.')
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def main(_):
# Construct the model config.
model_config = config.get_model_config(_VARIANT.value)
model_config.dtype = 'float32'
model_config.quant = _QUANT.value
image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
"lilly": "scripts/images/lilly.jpg",
"sunflower": "scripts/images/sunflower.JPG",
'golden_test_image': (
'scripts/images/test_image.jpg'
),
}
image = {}
for key in image_paths:
try:
image[key] = Image.open(image_paths[key]) # Open local file
image[key].show()
except IOError as e:
print(f"Error loading image: {e}")
exit()
# Seed random.
random.seed(_SEED.value)
np.random.seed(_SEED.value)
torch.manual_seed(_SEED.value)
# Create the model and load the weights.
device = torch.device(_DEVICE.value)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma3_model.Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
# model.load_weights(_CKPT.value)
model = model.to(device).eval()
print('Model loading done')
# Generate text only.
result = model.generate(
[
[
'<start_of_turn>user The capital of Italy'
' is?<end_of_turn>\n<start_of_turn>model'
],
[
'<start_of_turn>user What is your'
' purpose?<end_of_turn>\n<start_of_turn>model'
],
],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the results.
print('======================================')
print(f'Text only RESULT: {result}')
print('======================================')
# Generate golden Gemax test image.
result = model.generate(
[[
'<start_of_turn>user\n',
image['golden_test_image'],
'Caption this image. <end_of_turn>\n<start_of_turn>model',
]],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the result.
print('======================================')
print(f'Golden test image RESULT: {result}')
print('======================================')
# Generate text and image.
result = model.generate(
[[
'<start_of_turn>user\n',
image['cow_in_beach'],
(
'The name of the animal in the image is'
' <end_of_turn>\n<start_of_turn>model'
),
]],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the result.
print('======================================')
print(f'Single image RESULT: {result}')
print('======================================')
# Generate interleave text and multiple images.
result = model.generate(
[[
'<start_of_turn>user\nThis image',
image['lilly'],
'and this image',
image['sunflower'],
'are similar because? <end_of_turn>\n<start_of_turn>model',
]],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the result.
print('======================================')
print(f'Interleave images RESULT: {result}')
print('======================================')
if __name__ == '__main__':
app.run(main)
@@ -0,0 +1,267 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import random
import socket
import sys
from typing import List, Union
import numpy as np
import torch
import torch.multiprocessing
from gemma.config import GemmaConfig, get_model_config
from gemma.model_xla import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import gemma.xla_model_parallel as xla_model_parallel
USE_CUDA = os.environ.get('USE_CUDA', False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
else:
# Choose an available port.
with contextlib.closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
MASTER_PORT = str(s.getsockname()[1])
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def generate(
i: int,
model_config: GemmaConfig,
ckpt_path: str,
prompts: List[str],
output_lens: List[int],
temperatures: Union[List[float], None],
top_ps: List[float],
top_ks: List[int],
seed: int
):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device("cuda", local_rank)
torch.cuda.set_device(local_rank)
else:
device = xm.xla_device()
xm.set_rng_state(seed, device)
rank = xla_model_parallel.get_model_parallel_rank()
world_size = xla_model_parallel.get_model_parallel_world_size()
if rank > 0:
sys.stdout = open(os.devnull, 'w')
# build, load and compile model.
with _set_default_tensor_type(model_config.get_dtype()):
model = GemmaForCausalLM(model_config, world_size, rank, device)
model.load_weights(ckpt_path)
model = model.to(device).eval()
# create tokenizer.
tokenizer = Tokenizer(model_config.tokenizer)
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
batch_size = len(prompts)
if temperatures is not None:
assert batch_size == len(temperatures)
assert batch_size == len(top_ps)
assert batch_size == len(top_ks)
max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
assert max_seq_len <= model_config.max_position_embeddings
if model_config.num_key_value_heads < world_size:
assert world_size % model_config.num_key_value_heads == 0
n_local_heads = 1
else:
assert model_config.num_key_value_heads % world_size == 0
n_local_heads = model_config.num_key_value_heads // world_size
# build KV caches
kv_caches = []
for _ in range(model_config.num_hidden_layers):
k_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
v_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
kv_caches.append((k_cache, v_cache))
# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
tokenizer.pad_id,
dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
input_token_ids_tensor = input_token_ids_tensor.to(device)
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
temperatures_tensor = None if not temperatures else torch.FloatTensor(temperatures).to(device)
top_ps_tensor = torch.FloatTensor(top_ps).to(device)
top_ks_tensor = torch.LongTensor(top_ks).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
if not USE_CUDA:
xm.mark_step()
# Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
next_token_ids, _ = model(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
output_index = output_index + 1
if not USE_CUDA:
xm.mark_step()
# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
output_lens[i]]
if tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(tokenizer.decode(trimmed_output))
for prompt, result in zip(prompts, results):
print('======================================')
print(f'PROMPT: {prompt}')
print(f'RESULT: {result}')
print('======================================')
def main(args):
model_config = get_model_config(args.variant)
model_config.quant = args.quant
prompts = [args.prompt]
n = len(prompts)
output_lengths = [args.output_len] * n
temperatures = [0.95] * n
top_ps = [1.0] * n
top_ks = [100] * n
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
torch.multiprocessing.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
else:
xmp.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "2b-v2", "7b", "9b", "27b"])
parser.add_argument("--output_len", type=int, default=4)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()
main(args)
Binary file not shown.