docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -14,6 +14,7 @@ Research corpus and implementation guidance for Google Gemma 4, based on product
|
||||
| `CORPUS_capabilities.md` | Modalities (vision, audio, video, tools), what it can/can't do | When scoping what Gemma 4 can handle |
|
||||
| `CORPUS_benchmarks.md` | Full benchmark table vs Gemma 3, arena scores, agentic scores | When comparing Gemma 4 to alternatives |
|
||||
| `CORPUS_tool_calling_format.md` | Native token format + JSON API format for function calling | When implementing tool calling |
|
||||
| `tooling/` | **Canonical upstream tooling** — real scripts, notebooks, model cards, and configs pulled from Google / HF / framework maintainers (147 files). Subdirs: `google-official/`, `huggingface/`, `inference-frameworks/`, `gemma-family/`, `fine-tuning/`. See `tooling/README.md` for index and findings that update the older `CORPUS_*` docs | When you need authoritative source material — model cards, chat templates, fine-tuning recipes, serving commands for vLLM / llama.cpp / MLX, or to scope a specialized sibling (ShieldGemma, EmbeddingGemma, etc.) |
|
||||
|
||||
## Source Projects
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# Gemma 4 — Canonical Tooling Corpus
|
||||
|
||||
Actual scripts, notebooks, model cards, and configs downloaded from Google, Hugging Face, and the canonical framework maintainers. Populated 2026-04-18 by parallel research across five lanes. 147 files, ~14 MB.
|
||||
|
||||
**Triage: read the subdirectory README that matches your task, not this one.** This file is an index.
|
||||
|
||||
## Directory map
|
||||
|
||||
| Dir | What's there | When to open it |
|
||||
|-----|--------------|-----------------|
|
||||
| `google-official/` | `google-deepmind/gemma` JAX/Flax examples, `google/gemma_pytorch` scripts, `gemma.cpp` README + `gemma_api_server` docs, `google-gemma/cookbook` notebooks, official ai.google.dev HTML snapshots, Gemma 3 tech report PDF | Before trusting any non-Google source; when you need the authoritative prompt format or function-calling spec |
|
||||
| `huggingface/` | All 8 `google/gemma-4-*` model cards, chat-template `.jinja` files, `tokenizer_config.json` (with response-schema regex), transformers `gemma4/` source, official Gemma 4 Spaces `app.py`, HF launch blog posts | Before writing any transformers / `processor` integration; for the canonical chat-template handling |
|
||||
| `inference-frameworks/` | Comparison table across vLLM / llama.cpp / MLX / Keras-hub / TGI / Gemini API / Vertex AI. Real launch commands in `run_commands.sh`, 9 code snippets under `snippets/` | When picking a non-Ollama runtime; when you need audio/video input (Ollama doesn't expose it) |
|
||||
| `gemma-family/` | 12 per-variant briefs: ShieldGemma 2, CodeGemma, PaliGemma 2, RecurrentGemma, DataGemma, MedGemma, TxGemma, EmbeddingGemma, TranslateGemma, FunctionGemma, DolphinGemma, SignGemma + `index.md` | When scoping a project that needs a specialized sister model (embeddings, safety, vision-grounded, translation, tool routing) |
|
||||
| `fine-tuning/` | Unsloth Gemma 4 notebooks (text/vision/audio/GRPO), Axolotl Gemma 4 YAMLs (including 26B-A4B MoE), TRL reference scripts, Google cookbook fine-tune notebooks, `recipe-recommendation.md` with Seth's homelab-specific path | Before spending a dollar on cloud GPU or starting any Gemma 4 fine-tune |
|
||||
|
||||
## Findings that update / contradict the existing corpus
|
||||
|
||||
These are real gaps worth patching into `SYNTHESIS.md`, `GOTCHAS.md`, or `CORPUS_tool_calling_format.md`. Flagged here, not applied — the user asked for research, not a rewrite.
|
||||
|
||||
1. **Prompt-token format changed in Gemma 4.** Gemma 1/2/3 used `<start_of_turn>user ... <end_of_turn>`. Gemma 4 uses asymmetric pipe-brackets: `<|turn>user\n ... <turn|>`. Also new: `<|think|>`, `<|channel>thought...<channel|>`, `<|tool>`, `<|tool_call>`, `<|tool_response>` (+ inverses), `<|image>`, `<|audio>`, and string delimiter `<|"|>`. The existing `CORPUS_tool_calling_format.md` documents the tool tokens but doesn't reflect the turn-token change or the thinking/channel tokens. Canonical source: `huggingface/model-cards/gemma-4-31B-it-chat_template.jinja` and `google-official/docs/ai-google-dev_prompt_formatting_gemma4.html`.
|
||||
|
||||
2. **`google/gemma_pytorch` is abandoned for Gemma 4.** Last push 2025-05-30; the variants validator rejects Gemma 4 IDs. Anyone pointing at it as the PyTorch reference is wrong — use HF `transformers` or `google-deepmind/gemma` (JAX/Flax) instead.
|
||||
|
||||
3. **`gemma.cpp` ships a Gemini-API-compatible local HTTP server** (`gemma_api_server`, endpoint `POST /v1beta/models/<model>:generateContent`, SSE streaming). This is a Google-authored alternative to Ollama that speaks the real Gemini REST API — possibly the single most interesting discovery in this research pass. See `google-official/gemma-cpp/API_SERVER_README.md`.
|
||||
|
||||
4. **Transformers exposes `AutoModelForMultimodalLM` (new AutoClass)** — not `AutoModelForCausalLM`. It also exposes `processor.parse_response(..., response_schema=...)` driven from `tokenizer_config.json`, which replaces the hand-rolled regex in the current `CORPUS_tool_calling_format.md`. Pin: `transformers>=5.5.4`.
|
||||
|
||||
5. **Gemma 4 breaks Flash Attention.** FA2's max head_dim is 256, FA4's is 128, and Gemma 4's global head_dim is 512. Use SDP or Flex Attention. Axolotl hard-codes `sdp_attention: true` for Gemma 4. This belongs in `GOTCHAS.md`.
|
||||
|
||||
6. **The 26B variant is a MoE** — `gemma-4-26B-A4B` (A4B = 4B active per token). Quantization rules differ: Unsloth says use 16-bit LoRA, not 4-bit QLoRA, for acceptable quality. Axolotl's ScatterMoE + expert-LoRA config is the only tool validated for 4-bit MoE training. Worth a line in `CORPUS_ollama_variants.md`.
|
||||
|
||||
7. **No Gemma 4 technical report PDF exists yet** as of 2026-04-18. DeepMind repo says "Gemma 4 (Coming soon)". Gemma 3 report (downloaded at `google-official/tech-report/Gemma3Report.pdf`) remains the closest authoritative family citation.
|
||||
|
||||
8. **No `google/gemma-4-*` specialized siblings yet** — ShieldGemma, CodeGemma, PaliGemma, MedGemma, DataGemma are all still on Gemma 2 or 3 base. Historical lag is 3–6 months; expect siblings-on-4 mid-to-late 2026.
|
||||
|
||||
9. **No Gemma-4-specific TRL script in `huggingface/trl` yet.** HF blog says "fully supported," but the SFT/DPO/GRPO examples are still on Gemma 3 model IDs. Drop-in with `model_id` swap works. Only Gemma-4-dedicated TRL example today is `huggingface-gemma-recipes/carla_vlm_gemma.py` (VLM GRPO).
|
||||
|
||||
10. **HF Spaces `app.py` files are the shortest Gemma 4 inference examples** — Google and HF both use them as ref. See `huggingface/spaces/huggingface-projects_gemma-4-{31b,e4b}-it-app.py`.
|
||||
|
||||
## Immediate homelab plug-ins (from the gemma-family research)
|
||||
|
||||
- **EmbeddingGemma (308M)** — 100+ languages, Matryoshka to 128d. Drop-in upgrade from `nomic-embed-text` on both hosts.
|
||||
- **FunctionGemma (270M)** — cheap tool-router in front of `mortdecai:*` (latency win on hot path).
|
||||
- **PaliGemma 2 3B-448** — vision grounding with bbox output for AI_Visualizer / AI visualizer CT 167 alongside SDXL.
|
||||
- **TranslateGemma 4B** — useful for the family history agent (German/Polish sources).
|
||||
|
||||
## Source-url discipline
|
||||
|
||||
Every URL in the subdirectory READMEs was fetched and verified, not reconstructed from training. If a downloaded file is wrong, `git log` will show when it was pulled; the agent transcripts are the record of the source commit. Upstream repos can and do rename paths (see: `google-gemini/gemma-cookbook` → `google-gemma/cookbook`). Re-verify before citing externally.
|
||||
@@ -0,0 +1,281 @@
|
||||
# Gemma 4 Fine-Tuning Tooling — Index
|
||||
|
||||
Research captured 2026-04-18. All downloads verified against upstream repos.
|
||||
|
||||
## TL;DR
|
||||
|
||||
| Tool | Gemma 4 coverage | GPU floor (LoRA) | GPU floor (full FT) | Best at |
|
||||
|------|------------------|------------------|---------------------|---------|
|
||||
| **Unsloth** | Full parity — all 4 sizes, text/vision/audio/GRPO/RL | E2B: 8 GB, E4B: 17 GB, 26B A4B: ~40 GB, 31B QLoRA: 22 GB | Not recommended locally | **Fastest path**, Google-blessed, free Colab |
|
||||
| **TRL** | Partial — no `sft_gemma4.py` yet; `sft_gemma3.py` + `AutoModelForImageTextToText` works | Same as Unsloth w/ `load_in_4bit` | 2x H100 min for 31B | Research-grade control, DPO/GRPO/online RL, VLM GRPO on Gemma 4 (CARLA) |
|
||||
| **Axolotl** | **Native Gemma 4 configs shipped** (`examples/gemma4/`) | Single 5090 (32 GB) for 26B A4B QLoRA validated | >80 GB, "not tested" per README | Declarative YAML, multi-GPU FSDP, MoE expert LoRA |
|
||||
| **Google cookbook** | `docs/core/*` notebooks default to `google/gemma-4-E2B` | Depends on Colab tier | L4 (22 GB) for E4B QLoRA | Canonical baseline, paired with ai.google.dev docs |
|
||||
| **HF gemma-recipes** | Inference + one GRPO VLM script (CARLA) | E2B on T4 | — | VLM GRPO with tool-calling environment |
|
||||
| **Ollama** | Serves fine-tuned Gemma 4 via Modelfile `ADAPTER` | — | — | Final serving step |
|
||||
|
||||
**Recommendation for Seth: Unsloth.** See `recipe-recommendation.md`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Unsloth (`unsloth/`)
|
||||
|
||||
**Upstream:** `unslothai/notebooks`, `unslothai/unsloth`
|
||||
**License:** LGPL-3.0 (notebooks), Apache-2.0 (library)
|
||||
**Published Gemma 4 Dynamic quants:**
|
||||
- `unsloth/gemma-4-{E2B,E4B,31B,26B-A4B}-{,it}-unsloth-bnb-4bit` (dynamic 4-bit)
|
||||
- `unsloth/gemma-4-{E2B,E4B,31B,26B-A4B}-it-GGUF` (GGUF for inference)
|
||||
- Collection: https://huggingface.co/collections/unsloth/gemma-4
|
||||
|
||||
**Downloaded files (local paths under this directory):**
|
||||
- `unsloth/notebooks/Gemma4_(E2B)-Text.ipynb` — **canonical SFT notebook, T4-compatible**
|
||||
- `unsloth/notebooks/Gemma4_(E4B)-Text.ipynb` — 10 GB VRAM, higher accuracy
|
||||
- `unsloth/notebooks/Gemma4_(26B_A4B)-Text.ipynb` — MoE SFT (needs A100+)
|
||||
- `unsloth/notebooks/Gemma4_(31B)-Text.ipynb` — dense 31B SFT
|
||||
- `unsloth/notebooks/Gemma4_(E2B|E4B|26B_A4B|31B)-Vision.ipynb` — vision SFT w/ `UnslothVisionDataCollator`
|
||||
- `unsloth/notebooks/Gemma4_(E2B|E4B)-Audio.ipynb` — audio SFT (E2B/E4B only — 31B/26B have no audio encoder)
|
||||
- `unsloth/notebooks/Gemma4_(E2B)_GRPO.ipynb` — GRPO RL w/ Python reward funcs
|
||||
- `unsloth/notebooks/Gemma4_(E2B)_Reinforcement_Learning_{2048,Sudoku}_Game.ipynb` — game-playing RL
|
||||
- `unsloth/python_scripts/*.py` — same content as `.py` scripts (easier to grep/modify)
|
||||
- `unsloth/kaggle/Gemma4_(31B)-Text.ipynb`, `unsloth/kaggle/Gemma4_(E4B)-Text.ipynb` — Kaggle-flavored variants
|
||||
- `unsloth/docs/unsloth-README.md` — top-level Unsloth README
|
||||
|
||||
**Upstream URLs (useful to share):**
|
||||
- SFT E4B Colab: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma4_(E4B)-Text.ipynb
|
||||
- GRPO Colab: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma4_(E2B)_GRPO.ipynb
|
||||
- Unsloth Gemma 4 docs: https://unsloth.ai/docs/models/gemma-4/train
|
||||
|
||||
### Unsloth chat-template & masking detail (CRITICAL for Gemma 4)
|
||||
|
||||
Gemma 4 does **not** use Gemma 3's `<start_of_turn>` / `<end_of_turn>`. The new format is:
|
||||
|
||||
```
|
||||
<bos><|turn>user
|
||||
Hello<turn|>
|
||||
<|turn>model
|
||||
Hey there!<turn|>
|
||||
```
|
||||
|
||||
Unsloth's helper:
|
||||
```python
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(tokenizer, chat_template = "gemma-4") # literal "gemma-4", not "gemma4"
|
||||
```
|
||||
|
||||
Response-only masking (matches Unsloth's convention; everything *before* `response_part` is loss-masked):
|
||||
```python
|
||||
from unsloth.chat_templates import train_on_responses_only
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|turn>user\n",
|
||||
response_part = "<|turn>model\n",
|
||||
)
|
||||
```
|
||||
|
||||
`<bos>` gotcha: `apply_chat_template` prepends `<bos>`; Unsloth's `formatting_prompts_func` strips it with `.removeprefix('<bos>')` because the SFTTrainer's data collator adds its own — double `<bos>` silently degrades training.
|
||||
|
||||
**Tool tokens (`<|tool>`, `<|tool_call>`, `<|tool_response>`, `<|"|>`) are *not* masked** in Unsloth's default setup — they flow through as plain text inside user/assistant turns. If you're fine-tuning on tool-call data, include full `<|tool_call>...<tool_call|>` markup in the assistant `content` field; the template doesn't need a special `role=tool` branch.
|
||||
|
||||
### Unsloth MoE note
|
||||
|
||||
For 26B A4B (128 experts): Unsloth explicitly recommends **bf16/16-bit LoRA, NOT 4-bit QLoRA** ("MoE QLoRA not recommended, dense 31B is fine"). Their notebook uses `load_in_4bit = True` at >40 GB but the docs flag this as suboptimal.
|
||||
|
||||
---
|
||||
|
||||
## 2. TRL (`trl/`)
|
||||
|
||||
**Upstream:** `huggingface/trl`
|
||||
**License:** Apache-2.0
|
||||
|
||||
**Gemma 4-specific scripts:** NONE in `examples/scripts/` as of 2026-04-18. The canonical Gemma 4 TRL example lives in `huggingface-gemma-recipes/scripts/carla_vlm_gemma.py` (see next section).
|
||||
|
||||
**Closest-match Gemma 3 scripts downloaded (drop-in for Gemma 4 — change `model_id` to `google/gemma-4-*-it`, keep `AutoModelForImageTextToText`):**
|
||||
- `trl/sft_gemma3.py` — **use this as the Gemma 4 SFT template**. Pure text SFT (Codeforces-COTS).
|
||||
- `trl/sft_vlm_gemma3.py` — vision SFT template (uses `AutoModelForImageTextToText`, `all-linear` LoRA).
|
||||
- `trl/sft.py`, `trl/trl_scripts_sft.py` — the generic SFTTrainer wrappers.
|
||||
- `trl/sft_vlm.py` — model-agnostic VLM SFT.
|
||||
- `trl/dpo.py` — DPO (1-liner using TrlParser).
|
||||
- `trl/grpo_agent.py`, `trl/grpo_vlm.py` — GRPO with tool-calling environments.
|
||||
- `trl/sft_tiny_aya_tool_calling.py` — tool-calling SFT pattern.
|
||||
|
||||
**Chat template / masking detail:** TRL's `SFTTrainer` uses `tokenizer.apply_chat_template` end-to-end and delegates to the tokenizer's built-in Jinja template. For `google/gemma-4-*-it`, that template already produces `<|turn>user…<turn|>`. TRL supports `completion_only_loss` via the `SFTConfig(assistant_only_loss=True)` flag (TRL ≥ 0.22), which masks anything before the assistant turn — no manual `instruction_part` plumbing needed.
|
||||
|
||||
### Official HF blog says (verbatim):
|
||||
> "Gemma 4 is fully supported for fine-tuning with TRL. … we have prepared an example on how to fine-tune Gemma 4 with TRL on Vertex AI using SFT, to showcase how to extend the function calling capabilities, **whilst freezing both the vision and audio towers**."
|
||||
(see `huggingface-recipes/hf-blog-gemma4.md` §634-687)
|
||||
|
||||
---
|
||||
|
||||
## 3. Axolotl (`axolotl/`)
|
||||
|
||||
**Upstream:** `axolotl-ai-cloud/axolotl`, `examples/gemma4/`
|
||||
**License:** Apache-2.0
|
||||
**Gemma 4 status:** **Native support shipped**, day-one-class parity.
|
||||
|
||||
**Downloaded files:**
|
||||
- `axolotl/README.md` — official Axolotl Gemma 4 guide
|
||||
- `axolotl/31b-qlora.yaml` — 31B dense QLoRA, 1x80GB @ ~44 GB VRAM
|
||||
- `axolotl/31b-qlora-flex.yaml` — 31B dense QLoRA + Flex Attention, 1x80GB @ ~26 GB (40% less VRAM, 50% throughput cost)
|
||||
- `axolotl/26b-a4b-moe-qlora.yaml` — 26B MoE QLoRA + ScatterMoE expert-quantized + Expert-LoRA. Validated: 50 steps FineTome, loss 8.8→1.8, single RTX 5090 (32 GB), 21 GiB peak
|
||||
- `axolotl/e2b-vision-lora.yaml` — E2B vision LoRA with `freeze_mm_modules: true`
|
||||
|
||||
**Run command (from Axolotl README):**
|
||||
```bash
|
||||
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
|
||||
axolotl train examples/gemma4/31b-qlora.yaml
|
||||
axolotl train examples/gemma4/31b-qlora-flex.yaml
|
||||
```
|
||||
|
||||
### Axolotl chat template & masking detail
|
||||
|
||||
```yaml
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
```
|
||||
`chat_template: gemma4` (no dash — Axolotl's key is different from Unsloth's `"gemma-4"`). The template applies Gemma 4 turn tokens (`<|turn>user … <turn|>`). Masking is handled automatically by `type: chat_template` — only the assistant turn counts toward loss.
|
||||
|
||||
### Axolotl hard limitations for Gemma 4 (from their README)
|
||||
|
||||
- **Flash Attention OFF.** FA2 caps head_dim at 256; FA4 at 128; Gemma 4's `global_head_dim=512` exceeds both. **Use SDP or Flex Attention.** (`sdp_attention: true` in every yaml.)
|
||||
- **LoRA kernels OFF.** Due to Gemma 4's shared-KV layers (last N layers reuse K/V tensors): `lora_mlp_kernel: false`, `lora_qkv_kernel: false`, `lora_o_kernel: false`.
|
||||
- **`lora_target_linear` is incompatible** for multimodal. You MUST use `lora_target_modules` with the regex (see below) to restrict LoRA to the text decoder and NOT the vision/audio encoders.
|
||||
|
||||
Axolotl's canonical regex restricts LoRA to text layers only:
|
||||
```regex
|
||||
model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj
|
||||
```
|
||||
|
||||
For 26B A4B MoE, additionally target expert 3D tensors:
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- experts.gate_up_proj
|
||||
- experts.down_proj
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Google Cookbook (`google-cookbook/`)
|
||||
|
||||
**Upstream:** `google-gemma/cookbook`, `docs/core/`
|
||||
**License:** Apache-2.0
|
||||
**Gemma 4 status:** The `docs/core/*.ipynb` fine-tuning notebooks default to `google/gemma-4-E2B` as `model_id` — they ARE the Gemma 4 path, despite generic filenames.
|
||||
|
||||
**Downloaded files:**
|
||||
- `google-cookbook/huggingface_text_finetune_qlora.ipynb` — **text-to-SQL QLoRA tutorial** (gretel-synthetic-text-to-sql dataset, `philschmid/gretel-synthetic-text-to-sql`). This is the one ai.google.dev links to as the "official" fine-tune path.
|
||||
- `google-cookbook/huggingface_text_full_finetune.ipynb` — full-weights fine-tune variant
|
||||
- `google-cookbook/huggingface_vision_finetune_qlora.ipynb` — vision QLoRA on product descriptions
|
||||
- `google-cookbook/lora_tuning.ipynb` — LoRA concepts tutorial
|
||||
- `google-cookbook/function-calling-gemma4.ipynb` — official Google function-calling notebook (not a fine-tune, but the authoritative reference for tool-call tokens)
|
||||
- `google-cookbook/Gemma_4_HDP_Agentic_Security.ipynb` + `Gemma_4_HDP_README.md` — full-app fine-tune example (agentic security)
|
||||
|
||||
**Upstream URLs:**
|
||||
- https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora
|
||||
- https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
|
||||
- https://ai.google.dev/gemma/docs/capabilities/text/function-calling-gemma4
|
||||
|
||||
### Google cookbook chat template & masking detail (VERY IMPORTANT)
|
||||
|
||||
The cookbook notebooks use TRL's `SFTTrainer` with standard `messages` list (`role`/`content`) — chat-template is applied automatically by the tokenizer's built-in Jinja. No manual `instruction_part`/`response_part`.
|
||||
|
||||
**The non-obvious detail** is the `LoraConfig`:
|
||||
```python
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=16, lora_dropout=0.05, r=16, bias="none",
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
modules_to_save=["lm_head", "embed_tokens"], # NOTE
|
||||
ensure_weight_tying=True, # NOTE
|
||||
)
|
||||
```
|
||||
`modules_to_save=["lm_head","embed_tokens"]` + `ensure_weight_tying=True` is required because **Gemma 4 introduced new special tokens (`<|turn>`, `<|tool>`, `<|tool_call>`, `<|tool_response>`, `<|"|>`) that need their embeddings to be trainable in a fine-tune.** PEFT 0.15+ added `ensure_weight_tying` specifically for this case. Skipping it causes the adapter to see frozen random embeddings for the new tokens and training silently underperforms.
|
||||
|
||||
For vision, Google's cookbook uses plain `target_modules="all-linear"` (NO `exclude_modules`) — meaning it *does* train LoRA adapters on the vision tower. This is a different tradeoff from Axolotl (`freeze_mm_modules: true`) and from TRL's CARLA recipe (`exclude_modules=["vision_tower", "multi_modal_projector"]`). Pick based on whether your task needs the vision encoder to adapt (e.g., new image domain) or just the text decoder (most cases).
|
||||
|
||||
---
|
||||
|
||||
## 5. HuggingFace gemma-recipes (`huggingface-recipes/`)
|
||||
|
||||
**Upstream:** `huggingface/huggingface-gemma-recipes`
|
||||
**License:** Apache-2.0
|
||||
|
||||
**Downloaded files:**
|
||||
- `huggingface-recipes/carla_vlm_gemma.py` — **The canonical TRL + Gemma 4 example.** GRPO VLM training in a CARLA driving environment with tool calls. Shows `exclude_modules=["vision_tower", "multi_modal_projector"]`, `chat_template_kwargs={"enable_thinking": False}`, `max_tool_calling_iterations=10`.
|
||||
- `huggingface-recipes/Gemma4_(E2B)-Multimodal.ipynb` — **inference-only** multimodal demo (vision, video, audio, function calling, object detection). Not a fine-tune but necessary reference for the input format the training data must match.
|
||||
- `huggingface-recipes/README.md` — HF's top-level recipes index
|
||||
- `huggingface-recipes/hf-blog-gemma4.md` — the HF blog post's raw markdown (§630-707 is the fine-tuning section)
|
||||
|
||||
**Run command for the CARLA VLM RL example:**
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/trl.git
|
||||
python examples/scripts/openenv/carla_vlm_gemma.py \
|
||||
--env-urls https://sergiopaniego-carla-env.hf.space https://sergiopaniego-carla-env-2.hf.space \
|
||||
--model google/gemma-4-E2B-it
|
||||
```
|
||||
|
||||
**Known gap:** HF's gemma-recipes repo has *fine-tuning* notebooks for Gemma 3 and Gemma 3n (free T4 Colab) but **no pure-SFT Gemma 4 fine-tuning notebook yet** — the Gemma 4 Colab is inference only. Their blog points users to Unsloth Studio for the easy path.
|
||||
|
||||
---
|
||||
|
||||
## 6. Ollama / llama.cpp LoRA serving (`ollama-llamacpp/`)
|
||||
|
||||
**Downloaded:** `ollama-llamacpp/ollama-import-lora.md` — distilled from https://docs.ollama.com/import (2026-04-18 fetch).
|
||||
|
||||
**Short answer:** Yes, you can serve a Gemma 4 LoRA via Ollama. Two paths:
|
||||
|
||||
1. **Merge then serve (simpler, recommended):** `model.save_pretrained_merged("out", tokenizer, save_method="merged_16bit")` → `llama.cpp/convert_hf_to_gguf.py` → `llama.cpp/quantize` to Q4_K_M → `ollama create mymodel -f Modelfile` with `FROM ./gemma4-mortdecai.gguf`.
|
||||
2. **Adapter-only serve:** `llama.cpp/convert_lora_to_gguf.py` on the PEFT directory → Modelfile with `FROM gemma4:e4b-it-q8_0` + `ADAPTER ./adapter.gguf`.
|
||||
|
||||
Ollama's docs list supported architectures as Llama/Mistral/Gemma 1-2 — Gemma 4 isn't *explicitly* listed, but llama.cpp has day-one Gemma 4 support and in practice the path works. (Vision-adapter serving via Ollama is still a grey area.)
|
||||
|
||||
---
|
||||
|
||||
## 7. Datasets the canonical tutorials pair with Gemma 4
|
||||
|
||||
| Tutorial | Dataset | Format | Notes |
|
||||
|----------|---------|--------|-------|
|
||||
| Unsloth Gemma4 E4B Text | `mlabonne/FineTome-100k` | ShareGPT-style `conversations` field | Also the Axolotl default |
|
||||
| Unsloth Gemma4 GRPO | Synthetic kernel-optimization prompts in-notebook | Python reward funcs | RL w/ `function_works` / `check_only_stdlib_imports` |
|
||||
| Unsloth Gemma4 Vision | `unsloth/LaTeX_OCR` | HF image-text pairs | Demonstrates `UnslothVisionDataCollator` |
|
||||
| Google cookbook text QLoRA | `philschmid/gretel-synthetic-text-to-sql` | chat `messages` list | Google's "official" demo dataset for Gemma 4 |
|
||||
| Google cookbook vision QLoRA | `philschmid/amazon-product-descriptions-vlm` | image + text pairs | Product-description generation |
|
||||
| Axolotl Gemma 4 (all sizes) | `mlabonne/FineTome-100k` | `type: chat_template` | Validated in axolotl README |
|
||||
| Axolotl E2B vision LoRA | `HuggingFaceH4/llava-instruct-mix-vsft` | vision-language SFT | Same as HF's VLM template |
|
||||
| TRL sft_gemma3 (transfers) | `open-r1/codeforces-cots` | `messages` list | Chain-of-thought coding |
|
||||
| TRL carla_vlm_gemma (Gemma 4 VLM GRPO) | CARLA simulator (live) | environment rollouts | Multimodal tool responses |
|
||||
|
||||
No one uses Alpaca or UltraChat as the canonical Gemma 4 pair. **FineTome-100k is the unofficial standard** — both Unsloth and Axolotl default to it.
|
||||
|
||||
---
|
||||
|
||||
## 8. Chat-template-and-masking matrix (the debugging cheat sheet)
|
||||
|
||||
| Framework | chat_template key | Turn tokens | Response masking API | BOS handling |
|
||||
|-----------|-------------------|-------------|----------------------|--------------|
|
||||
| Unsloth | `"gemma-4"` | `<|turn>role\n...<turn|>` | `train_on_responses_only(instruction_part="<|turn>user\n", response_part="<|turn>model\n")` | Strip `<bos>` manually with `.removeprefix('<bos>')` before passing to trainer |
|
||||
| TRL | tokenizer's built-in Jinja (no key needed) | same | `SFTConfig(assistant_only_loss=True)` | Tokenizer handles automatically |
|
||||
| Axolotl | `chat_template: gemma4` (no dash) | same | automatic via `type: chat_template` | Automatic |
|
||||
| Google cookbook | tokenizer built-in Jinja | same | automatic via `SFTTrainer` + `messages` | Automatic |
|
||||
|
||||
Tool tokens (`<|tool>`, `<|tool_call>`, `<|tool_response>`, `<|"|>`) ride inside message content — none of the frameworks mask them specially, and none provide a `role="tool"` branch in the default template. If you're training tool-call data, put the complete `<|tool_call>call:{...}<tool_call|>` block in the assistant message `content`.
|
||||
|
||||
Also: **all Gemma 4 fine-tunes should `modules_to_save=["lm_head","embed_tokens"]` + `ensure_weight_tying=True`** in LoraConfig if you're using PEFT directly, because the new special-token embeddings need to be trainable. Unsloth and Axolotl handle this for you; naïve TRL + PEFT scripts do NOT by default.
|
||||
|
||||
---
|
||||
|
||||
## What's NOT here (and why)
|
||||
|
||||
- **Kaggle/Colab free-tier notebooks as a separate category** — the Unsloth notebooks *are* the free-tier notebooks. E2B Text runs on a free T4; 31B/26B-A4B need A100 Colab Pro. I pulled 2 Kaggle-flavored variants to `unsloth/kaggle/` for completeness.
|
||||
- **Google's DeepMind JAX/Flax Gemma 4 fine-tune script** — Google's DeepMind-gemma repo ships inference/reference code, not a SFT script. Google's *canonical* fine-tune path is the HF+TRL notebook in `google-gemma/cookbook` (above), NOT JAX. If you want JAX, see the archived `.archive/Gemma/[Gemma_1]Finetune_distributed.ipynb` pattern — not ported to Gemma 4.
|
||||
- **Full-weights 31B fine-tuning commands** — Axolotl's README says "heavy and has not been tested." Skip unless Seth rents an 8×H100 pod.
|
||||
- **Prompt engineering / inference-only notebooks** — per scope.
|
||||
|
||||
## See also
|
||||
|
||||
- `recipe-recommendation.md` — which tool Seth should actually use for his homelab, with the exact command.
|
||||
- `../../GOTCHAS.md` §"Fine-Tuning Ecosystem Issues" — day-one issues (required `mm_token_type_ids` field, Gemma4ClippableLinear PEFT issue, E2B/E4B training loss 13-15 being normal).
|
||||
- `../../CORPUS_tool_calling_format.md` — the 6 tool-calling special tokens.
|
||||
@@ -0,0 +1,93 @@
|
||||
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
|
||||
#
|
||||
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
|
||||
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
|
||||
#
|
||||
# Key notes:
|
||||
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
|
||||
# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
|
||||
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
|
||||
|
||||
base_model: google/gemma-4-26B-A4B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
use_kernels: true
|
||||
use_scattermoe: true
|
||||
experts_implementation: scattermoe
|
||||
torch_compile: true
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
strict: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:10%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma4-26b-a4b-qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||
# using regex to match only the text decoder attention projections.
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
|
||||
lora_target_parameters:
|
||||
- experts.gate_up_proj
|
||||
- experts.down_proj
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA2 not supported
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -0,0 +1,71 @@
|
||||
base_model: google/gemma-4-31B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
torch_compile: true
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
strict: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:10%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma4-31b-qlora-flex
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA not supported
|
||||
flex_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -0,0 +1,69 @@
|
||||
base_model: google/gemma-4-31B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
torch_compile: false
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
strict: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:10%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma4-31b-qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||
# using regex to match only the text decoder attention projections.
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA not supported
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -0,0 +1,60 @@
|
||||
# Finetune Google's Gemma 4 with Axolotl
|
||||
|
||||
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
|
||||
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
|
||||
|
||||
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
|
||||
axolotl train examples/gemma4/31b-qlora.yaml
|
||||
|
||||
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
|
||||
axolotl train examples/gemma4/31b-qlora-flex.yaml
|
||||
```
|
||||
|
||||
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
|
||||
|
||||
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
|
||||
|
||||
## Flex Attention
|
||||
|
||||
Reduce ~40% VRAM (at the cost of up to half throughput) by setting the below (shown in `examples/gemma4/31b-qlora-flex.yaml`):
|
||||
|
||||
```yaml
|
||||
torch_compile: true
|
||||
flex_attention: true
|
||||
```
|
||||
|
||||
This works for both the MoE and Dense model.
|
||||
|
||||
## Limitations
|
||||
|
||||
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
|
||||
- **LoRA kernels**: Not supported due to KV-sharing layers.
|
||||
- **lora_target_linear**: Incompatible for multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
|
||||
|
||||
### TIPS
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy and has not been tested.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -0,0 +1,62 @@
|
||||
# Gemma 4 E2B Vision LoRA
|
||||
#
|
||||
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
|
||||
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
|
||||
|
||||
base_model: google/gemma-4-E2B-it
|
||||
processor_type: AutoProcessor
|
||||
freeze_mm_modules: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
# Required for vision/multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:100]
|
||||
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gemma4-e2b-vision-lora
|
||||
|
||||
adapter: lora
|
||||
sequence_len: 2048
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
# Target language model only — vision encoder is frozen via freeze_mm_modules
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
max_steps: 10
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
@@ -0,0 +1,526 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "colab-badge"
|
||||
},
|
||||
"source": [
|
||||
"<table align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/apps/Gemma_4_HDP_Agentic_Security/Gemma_4_HDP_Agentic_Security.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "byline"
|
||||
},
|
||||
"source": [
|
||||
"# Securing Gemma 4 Agentic Workflows with HDP\n",
|
||||
"\n",
|
||||
"**Author:** Asiri Dalugoda, Helixar Limited ([@asiridalugoda](https://github.com/asiridalugoda)) | [helixar.ai](https://helixar.ai)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gpu-instructions"
|
||||
},
|
||||
"source": [
|
||||
"## Before you begin\n",
|
||||
"\n",
|
||||
"This notebook requires a GPU runtime. To enable GPU in Colab:\n",
|
||||
"1. Go to **Runtime → Change runtime type**\n",
|
||||
"2. Set **Hardware accelerator** to **GPU** (T4 is sufficient for E4B)\n",
|
||||
"3. Click **Save**\n",
|
||||
"\n",
|
||||
"You will also need a **Hugging Face token** to download Gemma 4 (gated model):\n",
|
||||
"1. Go to [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
|
||||
"2. Create a token with **Read** access\n",
|
||||
"3. Accept the Gemma 4 model license at [huggingface.co/google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)\n",
|
||||
"4. Run the cell below to authenticate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hf-login"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "overview"
|
||||
},
|
||||
"source": [
|
||||
"# Securing Gemma 4 Agentic Workflows with HDP\n",
|
||||
"\n",
|
||||
"**Human Delegation Provenance (HDP)** is an open protocol that adds a cryptographic chain-of-custody to AI agent function calls — ensuring every tool invocation can be traced back to an authorized human principal.\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to integrate HDP with Gemma 4's native function-calling capability to:\n",
|
||||
"\n",
|
||||
"- **Verify** that Gemma 4's function calls were authorized by a human principal before execution\n",
|
||||
"- **Classify** actions by irreversibility (read-only → irreversible → physical actuation)\n",
|
||||
"- **Block** unauthorized or out-of-scope tool calls at the middleware layer\n",
|
||||
"- **Audit** every decision with a pre-execution log\n",
|
||||
"\n",
|
||||
"This is particularly relevant for Gemma 4 deployments on edge devices (Jetson Nano, Raspberry Pi) where the model may be directing physical actuators offline with no out-of-band authorization check.\n",
|
||||
"\n",
|
||||
"**References:**\n",
|
||||
"- HDP IETF draft: [draft-helixar-hdp-agentic-delegation-00](https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/)\n",
|
||||
"- HDP-P (physical AI agents): [DOI 10.5281/ZENODO.19332440](https://doi.org/10.5281/ZENODO.19332440)\n",
|
||||
"- Helixar: [helixar.ai](https://helixar.ai)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b3600ee25c8e"
|
||||
},
|
||||
"source": [
|
||||
"## Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "7a80251f52b3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q transformers torch cryptography"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ed80fe18f255"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Download the middleware\n",
|
||||
"!wget -q https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/apps/Gemma_4_HDP_Agentic_Security/hdp_middleware.py\n",
|
||||
"\n",
|
||||
"from hdp_middleware import (\n",
|
||||
" HDPDelegationToken,\n",
|
||||
" HDPMiddleware,\n",
|
||||
" IrreversibilityClass,\n",
|
||||
" DEFAULT_TOOL_CLASS_MAP,\n",
|
||||
")\n",
|
||||
"from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey\n",
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e88bdc7b7265"
|
||||
},
|
||||
"source": [
|
||||
"## 1. Load Gemma 4\n",
|
||||
"\n",
|
||||
"We use the 4B Effective model for this demo. For production agentic deployments, the 26B MoE or 31B Dense models are recommended."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1e4e7779806d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"# For edge/robotics use cases: swap to google/gemma-4-E2B-it\n",
|
||||
"MODEL_ID = \"google/gemma-4-E4B-it\"\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\n",
|
||||
" \"text-generation\",\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "d91e36cfb0b2"
|
||||
},
|
||||
"source": [
|
||||
"## 2. Define Tools\n",
|
||||
"\n",
|
||||
"Gemma 4 uses structured JSON function-calling. We define a tool set spanning different IrreversibilityClasses to demonstrate the middleware's classification behaviour."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1becdb52e7f8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TOOLS = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"get_weather\",\n",
|
||||
" \"description\": \"Get the current weather for a location.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\"type\": \"string\", \"description\": \"City name\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"send_email\",\n",
|
||||
" \"description\": \"Send an email to a recipient.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"to\": {\"type\": \"string\"},\n",
|
||||
" \"subject\": {\"type\": \"string\"},\n",
|
||||
" \"body\": {\"type\": \"string\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"to\", \"subject\", \"body\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"delete_file\",\n",
|
||||
" \"description\": \"Permanently delete a file by path.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"path\": {\"type\": \"string\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"path\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"actuate_robot_arm\",\n",
|
||||
" \"description\": \"Command a robot arm to move to a target position.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"joint_angles\": {\"type\": \"array\", \"items\": {\"type\": \"number\"}},\n",
|
||||
" \"force_limit_n\": {\"type\": \"number\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"joint_angles\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Tools indexed by name for lookup\n",
|
||||
"TOOL_REGISTRY = {t[\"name\"]: t for t in TOOLS}\n",
|
||||
"print(f\"Registered {len(TOOLS)} tools\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "722948b00a92"
|
||||
},
|
||||
"source": [
|
||||
"## 3. Issue an HDP Delegation Token\n",
|
||||
"\n",
|
||||
"The human principal generates an Ed25519 keypair and issues an HDT that specifies:\n",
|
||||
"- Which tools the agent is permitted to call\n",
|
||||
"- The maximum IrreversibilityClass the agent can act on\n",
|
||||
"- The token's lifetime"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "b0622c68dfa5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Human principal generates their signing keypair\n",
|
||||
"# In production: loaded from secure key storage (HSM, OS keychain, etc.)\n",
|
||||
"principal_private_key = Ed25519PrivateKey.generate()\n",
|
||||
"principal_public_key = principal_private_key.public_key()\n",
|
||||
"\n",
|
||||
"# Issue an HDT authorizing the Gemma 4 agent to call weather queries\n",
|
||||
"# and send emails (Class 0 and Class 2), but NOT delete files or actuate hardware\n",
|
||||
"token = HDPDelegationToken.issue(\n",
|
||||
" principal_id=\"alice@example.com\",\n",
|
||||
" agent_id=\"gemma4-agent-01\",\n",
|
||||
" scope=[\"get_weather\", \"send_email\"],\n",
|
||||
" max_class=IrreversibilityClass.CLASS_2,\n",
|
||||
" ttl_seconds=3600,\n",
|
||||
" private_key=principal_private_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(json.dumps(token.to_dict(), indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e206f950f4bc"
|
||||
},
|
||||
"source": [
|
||||
"## 4. Initialise the HDP Middleware\n",
|
||||
"\n",
|
||||
"The middleware takes the principal's **public key** only — it verifies but cannot issue tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e24676f528bf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"audit_log = []\n",
|
||||
"\n",
|
||||
"# Confirmation callback for Class 2 (irreversible) actions.\n",
|
||||
"# In production: this would invoke a push notification, SMS OTP,\n",
|
||||
"# or hardware confirmation device to the human principal.\n",
|
||||
"def require_human_confirmation(tool_name: str, parameters: dict) -> bool:\n",
|
||||
" print(f\"\\n⚠️ Class 2 action requested: {tool_name}\")\n",
|
||||
" print(f\" Parameters: {json.dumps(parameters, indent=4)}\")\n",
|
||||
" response = input(\" Confirm? [y/N]: \").strip().lower()\n",
|
||||
" return response == \"y\"\n",
|
||||
"\n",
|
||||
"middleware = HDPMiddleware(\n",
|
||||
" public_key=principal_public_key,\n",
|
||||
" tool_class_map=DEFAULT_TOOL_CLASS_MAP,\n",
|
||||
" confirmation_callback=require_human_confirmation,\n",
|
||||
" audit_log=audit_log,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"HDP middleware initialised.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "72d56542eba0"
|
||||
},
|
||||
"source": [
|
||||
"## 5. Gemma 4 Function Call → HDP Gate → Tool Execution\n",
|
||||
"\n",
|
||||
"This is the core integration pattern. Every function call Gemma 4 generates is passed through `middleware.gate()` before being forwarded to tool execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "da20bc191e71"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Simulated Gemma 4 function call outputs\n",
|
||||
"# In production these come from parsing Gemma 4's structured JSON output\n",
|
||||
"gemma_function_calls = [\n",
|
||||
" # ✅ Should ALLOW — Class 0, in scope\n",
|
||||
" {\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}},\n",
|
||||
"\n",
|
||||
" # ⚠️ Should CONFIRM then ALLOW — Class 2, in scope\n",
|
||||
" {\"name\": \"send_email\", \"parameters\": {\n",
|
||||
" \"to\": \"bob@example.com\",\n",
|
||||
" \"subject\": \"Weekly report\",\n",
|
||||
" \"body\": \"Please find attached.\"\n",
|
||||
" }},\n",
|
||||
"\n",
|
||||
" # ❌ Should BLOCK — Class 2, NOT in HDT scope\n",
|
||||
" {\"name\": \"delete_file\", \"parameters\": {\"path\": \"/data/important.csv\"}},\n",
|
||||
"\n",
|
||||
" # ❌ Should BLOCK — Class 3, physical actuation\n",
|
||||
" {\"name\": \"actuate_robot_arm\", \"parameters\": {\n",
|
||||
" \"joint_angles\": [0.0, -1.57, 0.0, -1.57, 0.0, 0.0],\n",
|
||||
" \"force_limit_n\": 50.0\n",
|
||||
" }},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(\"HDP VERIFICATION RESULTS\")\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"\n",
|
||||
"for call in gemma_function_calls:\n",
|
||||
" result = middleware.gate(call, token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "be0d0dd05bce"
|
||||
},
|
||||
"source": [
|
||||
"## 6. Audit Log\n",
|
||||
"\n",
|
||||
"Every decision is logged pre-execution. This is the HDP audit trail — a cryptographically linked record of what was authorized, by whom, and when."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e6dbab6d88d1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"\\nAUDIT LOG\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"for i, entry in enumerate(audit_log):\n",
|
||||
" status = \"✅ ALLOWED\" if entry.allowed else \"❌ BLOCKED\"\n",
|
||||
" print(f\"{i+1}. {status} | {entry.tool_name} | {entry.action_class.name} | {entry.reason}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bcadcb7040db"
|
||||
},
|
||||
"source": [
|
||||
"## 7. Token Expiry and Scope Violation Demo\n",
|
||||
"\n",
|
||||
"Demonstrate that expired tokens and out-of-scope calls are blocked regardless of the action class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "deb2e3b6b20e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# Issue a token that's already expired\n",
|
||||
"expired_token = HDPDelegationToken.issue(\n",
|
||||
" principal_id=\"alice@example.com\",\n",
|
||||
" agent_id=\"gemma4-agent-01\",\n",
|
||||
" scope=[\"get_weather\"],\n",
|
||||
" max_class=IrreversibilityClass.CLASS_0,\n",
|
||||
" ttl_seconds=-1, # expired immediately\n",
|
||||
" private_key=principal_private_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Testing expired token:\")\n",
|
||||
"middleware.gate({\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}}, expired_token)\n",
|
||||
"\n",
|
||||
"print(\"\\nTesting call outside HDT scope:\")\n",
|
||||
"middleware.gate({\"name\": \"delete_file\", \"parameters\": {\"path\": \"/etc/passwd\"}}, token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b8f4acddb6fa"
|
||||
},
|
||||
"source": [
|
||||
"## 8. Edge / Robotics Deployment (HDP-P)\n",
|
||||
"\n",
|
||||
"For Gemma 4 E2B/E4B running on Jetson Nano or Raspberry Pi and directing physical actuators, use the HDP-P extension. The key additions are:\n",
|
||||
"\n",
|
||||
"- **Embodiment context** — bind the token to a specific hardware ID\n",
|
||||
"- **Policy attestation** — hash the deployed model weights into the token\n",
|
||||
"- **Fleet delegation constraints** — prevent lateral movement across robot fleet\n",
|
||||
"- **Pre-execution logging** — write audit records *before* actuator commands are issued\n",
|
||||
"\n",
|
||||
"See the [HDP-P specification](https://doi.org/10.5281/ZENODO.19332440) for the full EDT extension structure."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fcf7b451d175"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Minimal HDP-P Embodied Delegation Token (EDT) extension example\n",
|
||||
"# This shows how to attach physical constraints to an HDT\n",
|
||||
"\n",
|
||||
"hdp_p_extension = {\n",
|
||||
" \"hdp-p\": {\n",
|
||||
" \"version\": \"0.1\",\n",
|
||||
" \"embodiment\": {\n",
|
||||
" \"type\": \"mobile\",\n",
|
||||
" \"platform\": \"raspberry-pi-5\",\n",
|
||||
" \"hardware_id\": \"rpi-serial-XXXX\", # TPM-attested in production\n",
|
||||
" \"workspace\": \"lab-zone-a\"\n",
|
||||
" },\n",
|
||||
" \"action_scope\": {\n",
|
||||
" \"permitted_actions\": [\"move_base\", \"read_sensor\"],\n",
|
||||
" \"excluded_zones\": [\"human-workspace\"],\n",
|
||||
" \"force_limit_n\": 10.0,\n",
|
||||
" \"max_velocity_ms\": 0.5\n",
|
||||
" },\n",
|
||||
" \"irreversibility\": {\n",
|
||||
" \"max_class\": 1, # Class 1 max for this token\n",
|
||||
" \"class2_requires_confirmation\": True,\n",
|
||||
" \"class3_prohibited\": True\n",
|
||||
" },\n",
|
||||
" \"policy_attestation\": {\n",
|
||||
" \"policy_hash\": \"sha256:abc123...\", # SHA-256 of deployed model weights\n",
|
||||
" \"training_run_id\": \"gemma4-e2b-it\",\n",
|
||||
" \"sim_validated\": True\n",
|
||||
" },\n",
|
||||
" \"delegation_scope\": {\n",
|
||||
" \"fleet_delegation_permitted\": False, # No lateral movement\n",
|
||||
" \"max_delegation_depth\": 0\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(\"HDP-P EDT extension structure:\")\n",
|
||||
"print(json.dumps(hdp_p_extension, indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b0af7c701dfc"
|
||||
},
|
||||
"source": [
|
||||
"## Summary\n",
|
||||
"\n",
|
||||
"| Layer | What it solves | Tool |\n",
|
||||
"|---|---|---|\n",
|
||||
"| Gemma 4 function calling | Model generates structured tool calls | `pipeline(\"text-generation\")` |\n",
|
||||
"| HDP middleware | Was this call authorized by a human? | `HDPMiddleware.gate()` |\n",
|
||||
"| HDP-P EDT extension | Is this physical action within delegated bounds? | `hdp_p_extension` |\n",
|
||||
"| Audit log | Pre-execution record of every decision | `audit_log` |\n",
|
||||
"\n",
|
||||
"The full HDP specification (IETF draft), HDP-P companion paper, TypeScript SDK, and Python bindings are available at:\n",
|
||||
"\n",
|
||||
"- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/\n",
|
||||
"- **HDP-P paper:** https://doi.org/10.5281/ZENODO.19332440\n",
|
||||
"- **GitHub:** https://github.com/Helixar-AI\n",
|
||||
"- **Site:** https://helixar.ai"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "Gemma_4_HDP_Agentic_Security.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
# Gemma 4 + HDP: Securing Agentic Function Calls
|
||||
|
||||
This example demonstrates how to integrate the **Human Delegation Provenance (HDP)** protocol with **Gemma 4's native function-calling** to cryptographically verify that every tool invocation was authorized by a human principal before execution.
|
||||
|
||||
## The problem
|
||||
|
||||
Gemma 4 is purpose-built for agentic workflows. Its native function-calling lets it autonomously call tools and APIs across multi-step plans — on anything from a cloud workstation to a Raspberry Pi running a robot offline.
|
||||
|
||||
This creates a gap: when Gemma 4 generates a function call, there is no verifiable record that a human principal authorized that specific action. An injected prompt, a compromised system prompt, or a lateral pivot from another agent can trigger function calls that are indistinguishable from legitimate requests at the tool interface.
|
||||
|
||||
HDP closes this gap.
|
||||
|
||||
## What HDP does
|
||||
|
||||
HDP (IETF draft: `draft-helixar-hdp-agentic-delegation-00`) provides:
|
||||
|
||||
- **Ed25519-signed Delegation Tokens (HDTs)** issued by a human principal
|
||||
- **Scope constraints** — which tools the agent is permitted to call
|
||||
- **Irreversibility classification** (Class 0–3) — from read-only to physical actuation
|
||||
- **Pre-execution verification** — the middleware gate runs *before* any tool executes
|
||||
- **Audit log** — a tamper-evident record of every authorization decision
|
||||
|
||||
For Gemma 4 on **edge devices directing physical actuators** (Jetson Nano, Raspberry Pi + robot arm), the HDP-P companion specification adds embodiment constraints, policy attestation, and fleet delegation controls.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| `Gemma_4_HDP_Agentic_Security.ipynb` | Full walkthrough notebook — load Gemma 4, issue tokens, gate function calls |
|
||||
| `hdp_middleware.py` | Drop-in middleware — `HDPMiddleware.gate()` wraps any Gemma 4 tool executor |
|
||||
|
||||
## Quick start
|
||||
|
||||
```python
|
||||
from hdp_middleware import HDPDelegationToken, HDPMiddleware, IrreversibilityClass
|
||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
|
||||
|
||||
# Human principal issues a delegation token
|
||||
private_key = Ed25519PrivateKey.generate()
|
||||
token = HDPDelegationToken.issue(
|
||||
principal_id="alice@example.com",
|
||||
agent_id="gemma4-agent-01",
|
||||
scope=["get_weather", "send_email"],
|
||||
max_class=IrreversibilityClass.CLASS_2,
|
||||
ttl_seconds=3600,
|
||||
private_key=private_key,
|
||||
)
|
||||
|
||||
# Middleware verifies every Gemma 4 function call before execution
|
||||
middleware = HDPMiddleware(public_key=private_key.public_key())
|
||||
|
||||
result = middleware.gate(
|
||||
function_call={"name": "send_email", "parameters": {"to": "bob@example.com", ...}},
|
||||
token=token,
|
||||
)
|
||||
|
||||
if result.allowed:
|
||||
execute_tool(function_call)
|
||||
```
|
||||
|
||||
## Irreversibility classes
|
||||
|
||||
| Class | Definition | Authorization |
|
||||
|---|---|---|
|
||||
| 0 | Fully reversible — reads, queries | HDT sufficient |
|
||||
| 1 | Reversible with effort — writes, moves | HDT sufficient |
|
||||
| 2 | Irreversible — send, delete, publish | HDT + principal confirmation |
|
||||
| 3 | Irreversible + potentially harmful — physical actuation | Dual-principal required (HDP-P) |
|
||||
|
||||
## References
|
||||
|
||||
- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/
|
||||
- **Zenodo DOI:** https://doi.org/10.5281/zenodo.19332023
|
||||
- **HDP-P (physical AI):** https://doi.org/10.5281/ZENODO.19332440
|
||||
- **Helixar:** https://helixar.ai
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -0,0 +1,980 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "pn1797sn9Jb_"
|
||||
},
|
||||
"source": [
|
||||
"##### Copyright 2025 Google LLC."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "uivh5PY69ISg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "O83CmJ2j9L3n"
|
||||
},
|
||||
"source": [
|
||||
"# Fine-Tune Gemma for Vision Tasks using Hugging Face Transformers and QLoRA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "f9673bd6"
|
||||
},
|
||||
"source": [
|
||||
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/docs/core/huggingface_vision_finetune_qlora.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google-gemma/cookbook/blob/main/docs/core/huggingface_vision_finetune_qlora.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fgoogle-gemma%2Fcookbook%2Fmain%2Fdocs%2Fcore%2Fhuggingface_vision_finetune_qlora.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://github.com/google-gemma/cookbook/blob/main/docs/core/huggingface_vision_finetune_qlora.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e624ec07"
|
||||
},
|
||||
"source": [
|
||||
"This guide walks you through how to fine-tune Gemma on a custom image and text dataset for a vision task (generating product descriptions) using Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [TRL](https://huggingface.co/docs/trl/index). You will learn:\n",
|
||||
"\n",
|
||||
"- What is Quantized Low-Rank Adaptation (QLoRA)\n",
|
||||
"- Setup development environment\n",
|
||||
"- Create and prepare the fine-tuning dataset\n",
|
||||
"- Fine-tune Gemma using TRL and the SFTTrainer\n",
|
||||
"- Test Model Inference and generate product descriptions from images and text.\n",
|
||||
"\n",
|
||||
"Note: This guide requires a GPU which support bfloat16 data type such as NVIDIA L4 or NVIDIA A100 and more than 16GB of memory.\n",
|
||||
"\n",
|
||||
"## What is Quantized Low-Rank Adaptation (QLoRA)\n",
|
||||
"\n",
|
||||
"This guide demonstrates the use of [Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314), which emerged as a popular method to efficiently fine-tune LLMs as it reduces computational resource requirements while maintaining high performance. In QloRA, the pretrained model is quantized to 4-bit and the weights are frozen. Then trainable adapter layers (LoRA) are attached and only the adapter layers are trained. Afterwards, the adapter weights can be merged with the base model or kept as a separate adapter.\n",
|
||||
"\n",
|
||||
"## Setup development environment\n",
|
||||
"\n",
|
||||
"The first step is to install Hugging Face Libraries, including TRL, and datasets to fine-tune open model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ba51aa79"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install Pytorch & other libraries\n",
|
||||
"%pip install torch tensorboard torchvision\n",
|
||||
"\n",
|
||||
"# Install Transformers\n",
|
||||
"%pip install transformers\n",
|
||||
"\n",
|
||||
"# Install Hugging Face libraries\n",
|
||||
"%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf pillow sentencepiece\n",
|
||||
"\n",
|
||||
"# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100\n",
|
||||
"#%pip install flash-attn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7ef3d54b"
|
||||
},
|
||||
"source": [
|
||||
"_Note: If you are using a GPU with Ampere architecture (such as NVIDIA L4) or newer, you can use Flash attention. Flash Attention is a method that significantly speeds computations up and reduces memory usage from quadratic to linear in sequence length, leading to acelerating training up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main)._\n",
|
||||
"\n",
|
||||
"You need a valid Hugging Face Token to publish your model. If you are running inside a Google Colab, you can securely use your Hugging Face Token using the Colab secrets otherwise you can set the token as directly in the `login` method. Make sure your token has write access too, as you push your model to the Hub during training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "b6d79c93"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Login into Hugging Face Hub\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "42c60525"
|
||||
},
|
||||
"source": [
|
||||
"## Create and prepare the fine-tuning dataset\n",
|
||||
"\n",
|
||||
"When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.\n",
|
||||
"\n",
|
||||
"As an example, this guide focuses on the following use case:\n",
|
||||
"\n",
|
||||
"- Fine-tuning a Gemma model to generate concise, SEO-optimized product descriptions for an ecommerce platform, specifically tailored for mobile search.\n",
|
||||
"\n",
|
||||
"This guide uses the [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm) dataset, a dataset of Amazon product descriptions, including product images and categories.\n",
|
||||
"\n",
|
||||
"Hugging Face TRL supports multimodal conversations. The important piece is the \"image\" role, which tells the processing class that it should load the image. The structure should follow:\n",
|
||||
"\n",
|
||||
"```json\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "c4ecf6db"
|
||||
},
|
||||
"source": [
|
||||
"You can now use the Hugging Face Datasets library to load the dataset and create a prompt template to combine the image, product name, and category, and add a system message. The dataset includes images as`Pil.Image` objects."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "40c3a2cf"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8d1259be3dfa4b1e899c97026276ee41",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"README.md: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a5554c0595144c949b578eb1cbdfd0fd",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00000-of-00001.parquet: 0%| | 0.00/47.6M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "9ed0567e2e4e40a88c7eddfe7d6a6e2f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating train split: 0%| | 0/1345 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[{'role': 'system', 'content': 'You are an expert product description writer for Amazon.'}, {'role': 'user', 'content': [{'type': 'text', 'text': \"Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\\nOnly return description. The description should be SEO optimized and for a better mobile search experience.\\n\\n<PRODUCT>\\nRazor Agitator BMX/Freestyle Bike, 20-Inch\\n</PRODUCT>\\n\\n<CATEGORY>\\nSports & Outdoors | Outdoor Recreation | Cycling | Kids' Bikes & Accessories | Kids' Bikes\\n</CATEGORY>\\n\"}, {'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x413 at 0x7B7250181790>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Conquer the streets with the Razor Agitator BMX Bike! This 20-inch freestyle bike is built for young riders ready to take on any challenge. Durable frame, responsive handling – perfect for tricks and cruising. Get yours today!'}]}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"# System message for the assistant\n",
|
||||
"system_message = \"You are an expert product description writer for Amazon.\"\n",
|
||||
"\n",
|
||||
"# User prompt that combines the user query and the schema\n",
|
||||
"user_prompt = \"\"\"Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\n",
|
||||
"Only return description. The description should be SEO optimized and for a better mobile search experience.\n",
|
||||
"\n",
|
||||
"<PRODUCT>\n",
|
||||
"{product}\n",
|
||||
"</PRODUCT>\n",
|
||||
"\n",
|
||||
"<CATEGORY>\n",
|
||||
"{category}\n",
|
||||
"</CATEGORY>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Convert dataset to OAI messages\n",
|
||||
"def format_data(sample):\n",
|
||||
" return {\n",
|
||||
" \"messages\": [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" #\"content\": [{\"type\": \"text\", \"text\": system_message}],\n",
|
||||
" \"content\": system_message,\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" \"text\": user_prompt.format(\n",
|
||||
" product=sample[\"Product Name\"],\n",
|
||||
" category=sample[\"Category\"],\n",
|
||||
" ),\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": sample[\"image\"],\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"assistant\",\n",
|
||||
" \"content\": [{\"type\": \"text\", \"text\": sample[\"description\"]}],\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"def process_vision_info(messages: list[dict]) -> list[Image.Image]:\n",
|
||||
" image_inputs = []\n",
|
||||
" # Iterate through each conversation\n",
|
||||
" for msg in messages:\n",
|
||||
" # Get content (ensure it's a list)\n",
|
||||
" content = msg.get(\"content\", [])\n",
|
||||
" if not isinstance(content, list):\n",
|
||||
" content = [content]\n",
|
||||
"\n",
|
||||
" # Check each content element for images\n",
|
||||
" for element in content:\n",
|
||||
" if isinstance(element, dict) and (\n",
|
||||
" \"image\" in element or element.get(\"type\") == \"image\"\n",
|
||||
" ):\n",
|
||||
" # Get the image and convert to RGB\n",
|
||||
" if \"image\" in element:\n",
|
||||
" image = element[\"image\"]\n",
|
||||
" else:\n",
|
||||
" image = element\n",
|
||||
" image_inputs.append(image.convert(\"RGB\"))\n",
|
||||
" return image_inputs\n",
|
||||
"\n",
|
||||
"# Load dataset from the hub\n",
|
||||
"dataset = load_dataset(\"philschmid/amazon-product-descriptions-vlm\", split=\"train\")\n",
|
||||
"dataset = dataset.train_test_split(test_size=0.1)\n",
|
||||
"\n",
|
||||
"# Convert dataset to OAI messages\n",
|
||||
"# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\n",
|
||||
"dataset_train = [format_data(sample) for sample in dataset[\"train\"]]\n",
|
||||
"dataset_test = [format_data(sample) for sample in dataset[\"test\"]]\n",
|
||||
"\n",
|
||||
"print(dataset_train[345][\"messages\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "c0eb2e06"
|
||||
},
|
||||
"source": [
|
||||
"## Fine-tune Gemma using TRL and the SFTTrainer\n",
|
||||
"\n",
|
||||
"You are now ready to fine-tune your model. Hugging Face TRL [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to supervise fine-tune open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality of life features, including:\n",
|
||||
"\n",
|
||||
"* Dataset formatting, including conversational and instruction formats\n",
|
||||
"* Training on completions only, ignoring prompts\n",
|
||||
"* Packing datasets for more efficient training\n",
|
||||
"* Parameter-efficient fine-tuning (PEFT) support including QloRA\n",
|
||||
"* Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)\n",
|
||||
"\n",
|
||||
"The following code loads the Gemma model and tokenizer from Hugging Face and initializes the quantization configuration.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "18069ed2"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "42e58727637d4495ad8c5f753c5bcd06",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b11ec04ab48043b9937cfa3822b4fa42",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7659ae83140247efacee26159ca363b6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/2011 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "28c7b23ad9ba4316a8c95992884ad1d7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"generation_config.json: 0%| | 0.00/149 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "44b08b5b2cad4385893e29d5240a98d7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"processor_config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6eec3330ff144b3c9ad863cc89ed5709",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"chat_template.jinja: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "47f8fdc1492e4bb9b8d8fe9535c97d2c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ff3072e44aec41b6a0f6a28aeba99c4d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3d3e0871ad0e4642a5e2ca6f4baeebe4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig\n",
|
||||
"\n",
|
||||
"# Hugging Face model id\n",
|
||||
"model_id = \"google/gemma-4-E2B\" # @param [\"google/gemma-4-E2B\",\"google/gemma-4-E4B\"] {\"allow-input\":true}\n",
|
||||
"\n",
|
||||
"# Check if GPU benefits from bfloat16\n",
|
||||
"if torch.cuda.get_device_capability()[0] < 8:\n",
|
||||
" raise ValueError(\"GPU does not support bfloat16, please use a GPU that supports bfloat16.\")\n",
|
||||
"\n",
|
||||
"# Define model init arguments\n",
|
||||
"model_kwargs = dict(\n",
|
||||
" dtype=torch.bfloat16, # What torch dtype to use, defaults to auto\n",
|
||||
" device_map=\"auto\", # Let torch decide how to load the model\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# BitsAndBytesConfig int-4 config\n",
|
||||
"model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True,\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||
" bnb_4bit_compute_dtype=model_kwargs[\"dtype\"],\n",
|
||||
" bnb_4bit_quant_storage=model_kwargs[\"dtype\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load model and tokenizer\n",
|
||||
"model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)\n",
|
||||
"processor = AutoProcessor.from_pretrained(\"google/gemma-4-E2B-it\") # Load the Instruction Tokenizer to use the official Gemma template"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "37ec1d1b"
|
||||
},
|
||||
"source": [
|
||||
"The `SFTTrainer` supports a built-in integration with `peft`, which makes it straightforward to efficiently tune LLMs using QLoRA. You only need to create a `LoraConfig` and provide it to the trainer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ed00e846"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from peft import LoraConfig\n",
|
||||
"\n",
|
||||
"peft_config = LoraConfig(\n",
|
||||
" lora_alpha=16,\n",
|
||||
" lora_dropout=0.05,\n",
|
||||
" r=16,\n",
|
||||
" bias=\"none\",\n",
|
||||
" target_modules=\"all-linear\",\n",
|
||||
" task_type=\"CAUSAL_LM\",\n",
|
||||
" modules_to_save=[\"lm_head\", \"embed_tokens\"], # make sure to save the lm_head and embed_tokens as you train the special tokens\n",
|
||||
" ensure_weight_tying=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bbd9fc1b"
|
||||
},
|
||||
"source": [
|
||||
"Before you can start your training, you need to define the hyperparameter you want to use in a `SFTConfig` and a custom `collate_fn` to handle the vision processing. The `collate_fn` converts the messages with text and images into a format that the model can understand.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "989be3c1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import SFTConfig\n",
|
||||
"\n",
|
||||
"args = SFTConfig(\n",
|
||||
" output_dir=\"gemma-product-description\", # directory to save and repository id\n",
|
||||
" num_train_epochs=3, # number of training epochs\n",
|
||||
" per_device_train_batch_size=1, # batch size per device during training\n",
|
||||
" optim=\"adamw_torch_fused\", # use fused adamw optimizer\n",
|
||||
" logging_steps=5, # log every 5 steps\n",
|
||||
" save_strategy=\"epoch\", # save checkpoint every epoch\n",
|
||||
" eval_strategy=\"epoch\", # evaluate checkpoint every epoch\n",
|
||||
" learning_rate=2e-4, # learning rate, based on QLoRA paper\n",
|
||||
" bf16=True, # use bfloat16 precision\n",
|
||||
" max_grad_norm=0.3, # max gradient norm based on QLoRA paper\n",
|
||||
" lr_scheduler_type=\"constant\", # use constant learning rate scheduler\n",
|
||||
" push_to_hub=True, # push model to hub\n",
|
||||
" report_to=\"tensorboard\", # report metrics to tensorboard\n",
|
||||
" dataset_text_field=\"\", # need a dummy field for collator\n",
|
||||
" dataset_kwargs={\"skip_prepare_dataset\": True}, # important for collator\n",
|
||||
" remove_unused_columns = False # important for collator\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create a data collator to encode text and image pairs\n",
|
||||
"def collate_fn(examples):\n",
|
||||
" texts = []\n",
|
||||
" images = []\n",
|
||||
" for example in examples:\n",
|
||||
" image_inputs = process_vision_info(example[\"messages\"])\n",
|
||||
" text = processor.apply_chat_template(\n",
|
||||
" example[\"messages\"], add_generation_prompt=False, tokenize=False\n",
|
||||
" )\n",
|
||||
" texts.append(text.strip())\n",
|
||||
" images.append(image_inputs)\n",
|
||||
"\n",
|
||||
" # Tokenize the texts and process the images\n",
|
||||
" batch = processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n",
|
||||
"\n",
|
||||
" # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation\n",
|
||||
" labels = batch[\"input_ids\"].clone()\n",
|
||||
"\n",
|
||||
" # Mask tokens for not being used in the loss computation\n",
|
||||
" labels[labels == processor.tokenizer.pad_token_id] = -100\n",
|
||||
" labels[labels == processor.tokenizer.boi_token_id] = -100\n",
|
||||
" labels[labels == processor.tokenizer.image_token_id] = -100\n",
|
||||
" labels[labels == processor.tokenizer.eoi_token_id] = -100\n",
|
||||
"\n",
|
||||
" batch[\"labels\"] = labels\n",
|
||||
" return batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dd88e798"
|
||||
},
|
||||
"source": [
|
||||
"You now have every building block you need to create your `SFTTrainer` to start the training of your model.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ade95df7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import SFTTrainer\n",
|
||||
"\n",
|
||||
"# Create Trainer object\n",
|
||||
"trainer = SFTTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=args,\n",
|
||||
" train_dataset=dataset_train,\n",
|
||||
" eval_dataset=dataset_test,\n",
|
||||
" peft_config=peft_config,\n",
|
||||
" processing_class=processor,\n",
|
||||
" data_collator=collate_fn,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "fad61a6a"
|
||||
},
|
||||
"source": [
|
||||
"Start training by calling the `train()` method."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "995e7e38"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='456' max='456' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [456/456 11:20, Epoch 3/3]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Epoch</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" <th>Validation Loss</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>1.326710</td>\n",
|
||||
" <td>1.441816</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>1.042711</td>\n",
|
||||
" <td>1.320613</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>0.739179</td>\n",
|
||||
" <td>1.458798</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Start training, the model will be automatically saved to the Hub and the output directory\n",
|
||||
"trainer.train()\n",
|
||||
"\n",
|
||||
"# Save the final model again to the Hugging Face Hub\n",
|
||||
"trainer.save_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b47b9733"
|
||||
},
|
||||
"source": [
|
||||
"Before you can test your model, make sure to free the memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "40a32ed7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# free the memory again\n",
|
||||
"del model\n",
|
||||
"del trainer\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "862e9728"
|
||||
},
|
||||
"source": [
|
||||
"When using QLoRA, you only train adapters and not the full model. This means when saving the model during training you only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This saves a default model, which can be used for inference.\n",
|
||||
"\n",
|
||||
"Note: It requires more than 30GB of CPU Memory when you want to merge the adapter into the model. You can skip this and continue with Test Model Inference.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "761e324b"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "20d63c526a854f2a880882c246ac3b3d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/2011 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8703db4619fd4f8eb66bf0cc2211dc7e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Writing model shards: 0%| | 0/5 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['merged_model/processor_config.json']"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from peft import PeftModel\n",
|
||||
"\n",
|
||||
"# Load Model base model\n",
|
||||
"model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)\n",
|
||||
"\n",
|
||||
"# Merge LoRA and base model and save\n",
|
||||
"peft_model = PeftModel.from_pretrained(model, args.output_dir)\n",
|
||||
"merged_model = peft_model.merge_and_unload()\n",
|
||||
"merged_model.save_pretrained(\"merged_model\", safe_serialization=True, max_shard_size=\"2GB\")\n",
|
||||
"\n",
|
||||
"processor = AutoProcessor.from_pretrained(args.output_dir)\n",
|
||||
"processor.save_pretrained(\"merged_model\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bf86e31d"
|
||||
},
|
||||
"source": [
|
||||
"## Test Model Inference and generate product descriptions\n",
|
||||
"\n",
|
||||
"After the training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.\n",
|
||||
"\n",
|
||||
"Note: Evaluating Generative AI models is not a trivial task since one input can have multiple correct outputs. This guide only focuses on manual evaluation and vibe checks.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "aab1c5c5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "89b0d1d25dba4e8e8642c41e69c4c65e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/2012 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_id = \"merged_model\"\n",
|
||||
"\n",
|
||||
"# Load Model with PEFT adapter\n",
|
||||
"model = AutoModelForImageTextToText.from_pretrained(\n",
|
||||
" model_id,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" dtype=\"auto\",\n",
|
||||
")\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "3dccb57c"
|
||||
},
|
||||
"source": [
|
||||
"You can test inference by providing a product name, category and image. The `sample` includes a marvel action figure.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1fd887f4"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<bos><|turn>system\n",
|
||||
"You are an expert product description writer for Amazon.<turn|>\n",
|
||||
"<|turn>user\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"<|image|>\n",
|
||||
"\n",
|
||||
"Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\n",
|
||||
"Only return description. The description should be SEO optimized and for a better mobile search experience.\n",
|
||||
"\n",
|
||||
"<PRODUCT>\n",
|
||||
"Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\n",
|
||||
"</PRODUCT>\n",
|
||||
"\n",
|
||||
"<CATEGORY>\n",
|
||||
"Toys & Games | Toy Figures & Playsets | Action Figures\n",
|
||||
"</CATEGORY><turn|>\n",
|
||||
"<|turn>model\n",
|
||||
"\n",
|
||||
"MODEL OUTPUT>> \n",
|
||||
"\n",
|
||||
"Enhance your collection with the Marvel Avengers - Avengers Assemble Ultron-Comforter Set! This soft and cuddly blanket and pillowcase feature everyone's favorite Avengers, Iron Man, and his loyal companion War Machine. Officially licensed by Marvel. Bring home the heroic team!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import requests\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"# Test sample with Product Name, Category and Image\n",
|
||||
"sample = {\n",
|
||||
" \"product_name\": \"Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\",\n",
|
||||
" \"category\": \"Toys & Games | Toy Figures & Playsets | Action Figures\",\n",
|
||||
" \"image\": Image.open(requests.get(\"https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg\", stream=True).raw).convert(\"RGB\")\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"def generate_description(sample, model, processor):\n",
|
||||
" # Convert sample into messages and then apply the chat template\n",
|
||||
" messages = [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": [\n",
|
||||
" {\"type\": \"image\",\"image\": sample[\"image\"]},\n",
|
||||
" {\"type\": \"text\", \"text\": user_prompt.format(product=sample[\"product_name\"], category=sample[\"category\"])},\n",
|
||||
" ]},\n",
|
||||
" ]\n",
|
||||
" text = processor.apply_chat_template(\n",
|
||||
" messages, tokenize=False, add_generation_prompt=True\n",
|
||||
" )\n",
|
||||
" print(text)\n",
|
||||
" # Process the image and text\n",
|
||||
" image_inputs = process_vision_info(messages)\n",
|
||||
" # Tokenize the text and process the images\n",
|
||||
" inputs = processor(\n",
|
||||
" text=[text],\n",
|
||||
" images=image_inputs,\n",
|
||||
" padding=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" )\n",
|
||||
" # Move the inputs to the device\n",
|
||||
" inputs = inputs.to(model.device)\n",
|
||||
"\n",
|
||||
" # Generate the output\n",
|
||||
" stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids(\"<turn|>\")]\n",
|
||||
" generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)\n",
|
||||
" # Trim the generation and decode the output to text\n",
|
||||
" generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]\n",
|
||||
" output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
" )\n",
|
||||
" return output_text[0]\n",
|
||||
"\n",
|
||||
"# generate the description\n",
|
||||
"description = generate_description(sample, model, processor)\n",
|
||||
"print(\"MODEL OUTPUT>> \\n\")\n",
|
||||
"print(description)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "6f8ff452"
|
||||
},
|
||||
"source": [
|
||||
"## Summary and next steps\n",
|
||||
"\n",
|
||||
"This tutorial covered how to fine-tune a Gemma model for vision tasks using TRL and QLoRA, specifically for generating product descriptions. Check out the following docs next:\n",
|
||||
"\n",
|
||||
"* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
|
||||
"* Learn how to [fine-tune Gemma for text tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora).\n",
|
||||
"* Learn how to [full model fine-tune using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune).\n",
|
||||
"* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
|
||||
"* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
|
||||
"* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"name": "huggingface_vision_finetune_qlora.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,789 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "G3MMAcssHTML"
|
||||
},
|
||||
"source": [
|
||||
"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
|
||||
"<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Tce3stUlHN0L"
|
||||
},
|
||||
"source": [
|
||||
"##### Copyright 2025 Google LLC."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "tuOe1ymfHZPu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SDEExiAk4fLb"
|
||||
},
|
||||
"source": [
|
||||
"# Fine-tune Gemma in Keras using LoRA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ZFWzQEqNosrS"
|
||||
},
|
||||
"source": [
|
||||
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/lora_tuning\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/docs/core/lora_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/docs/core/lora_tuning.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fgoogle-gemini%2Fgemma-cookbook%2Fmain%2Fdocs%2Fcore%2Flora_tuning.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://github.com/google-gemini/gemma-cookbook/blob/main/docs/core/lora_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "lSGRSsRPgkzK"
|
||||
},
|
||||
"source": [
|
||||
"Generative artificial intelligent (AI) models like Gemma are effective at a variety of tasks. You can further fine-tune Gemma models with domain-specific data to perform tasks such as sentiment analysis. However, full fine-tuning of generative models by updating billions of parameters is resource intensive, requiring specialized hardware, such as GPUs, processing time, and memory to load the model parameters.\n",
|
||||
"\n",
|
||||
"[Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This technique makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs. This tutorial walks you through using Keras to perform LoRA fine-tuning on a Gemma model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "lyhHCMfoRZ_v"
|
||||
},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
|
||||
"\n",
|
||||
"* Get access to Gemma on [kaggle.com](https://kaggle.com).\n",
|
||||
"* Select a Colab runtime with sufficient resources to tune\n",
|
||||
" the Gemma model you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
|
||||
"* Generate and configure a Kaggle username and API key.\n",
|
||||
"\n",
|
||||
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AZ5Qo0fxRZ1V"
|
||||
},
|
||||
"source": [
|
||||
"### Select a Colab runtime\n",
|
||||
"\n",
|
||||
"To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
|
||||
"\n",
|
||||
"1. In the upper-right of the Colab window, select ▾ (**Additional connection options**).\n",
|
||||
"2. Select **Change runtime type**.\n",
|
||||
"3. Under **Hardware accelerator**, select **T4 GPU**."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hsPC0HRkJl0K"
|
||||
},
|
||||
"source": [
|
||||
"### Configure your API key\n",
|
||||
"\n",
|
||||
"To use Gemma, you must provide your Kaggle username and a Kaggle API key.\n",
|
||||
"\n",
|
||||
"To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
|
||||
"\n",
|
||||
"In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7iOF6Yo-wUEC"
|
||||
},
|
||||
"source": [
|
||||
"### Set environment variables\n",
|
||||
"\n",
|
||||
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0_EdOg9DPK6Q"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
|
||||
"# vars as appropriate for your system.\n",
|
||||
"\n",
|
||||
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
|
||||
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "CuEUAKJW1QkQ"
|
||||
},
|
||||
"source": [
|
||||
"### Install Keras packages\n",
|
||||
"\n",
|
||||
"Install the Keras and KerasHub Python packages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1eeBtYqJsZPG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q -U keras-hub\n",
|
||||
"!pip install -q -U keras"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "rGLS-l5TxIR4"
|
||||
},
|
||||
"source": [
|
||||
"### Select a backend\n",
|
||||
"\n",
|
||||
"Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch. For this tutorial, configure the backend for JAX as it typically provides the better performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "yn5uy8X8sdD0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\".\n",
|
||||
"# Avoid memory fragmentation on JAX backend.\n",
|
||||
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hZs8XXqUKRmi"
|
||||
},
|
||||
"source": [
|
||||
"### Import packages\n",
|
||||
"\n",
|
||||
"Import the Python packages needed for this tutorial, including Keras and KerasHub."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "FYHyPUA9hKTf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import keras\n",
|
||||
"import keras_hub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7RCE3fdGhDE5"
|
||||
},
|
||||
"source": [
|
||||
"## Load model\n",
|
||||
"\n",
|
||||
"Keras provides implementations of Gemma and many other popular [model architectures](https://keras.io/keras_hub/api/models/). Use the `Gemma3CausalLM.from_preset()` method to configure an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "vz5zLEyLstfn"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\"gemma3_instruct_1b\")\n",
|
||||
"gemma_lm.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Nl4lvPy5zA26"
|
||||
},
|
||||
"source": [
|
||||
"The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma#_xxxxxxx\"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "G_L6A5J-1QgC"
|
||||
},
|
||||
"source": [
|
||||
"## Inference before fine tuning\n",
|
||||
"\n",
|
||||
"Once you have downloaded and configured a Gemma model, you can query it with various prompts to see how it responds."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PVLXadptyo34"
|
||||
},
|
||||
"source": [
|
||||
"### Europe trip prompt\n",
|
||||
"\n",
|
||||
"Query the model for suggestions on what to do on a trip to Europe."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZwQz3xxxKciD"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"What should I do on a trip to Europe?\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"The first thing to know is that you will have a great time!\n",
|
||||
"\n",
|
||||
"Europe is a great place for a vacation. The countries of Europe are all very different and offer a wide range of activities and attractions. The countries of Europe are also very close to each other, which means you can visit many different places within a short time.\n",
|
||||
"\n",
|
||||
"The best way to plan a trip to Europe is to look up the countries you want to visit and see what activities are offered in each country. You can also look for tours and tours that offer a good value for money.\n",
|
||||
"\n",
|
||||
"You can also look for hotels and flights that offer good deals. If you are looking for a good value for money, you should look for hotels and flights that offer good deals. This means you will have a great time on your trip!\n",
|
||||
"\n",
|
||||
"The next step is to book your tickets to the countries you want to visit. If you are planning to visit many countries, it's a good idea to book your tickets early. This means you’ll be able to get the best deal and avoid the long queues.\n",
|
||||
"\n",
|
||||
"The next step is to plan your itinerary. You can use a travel guide to plan your itinerary\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n",
|
||||
"\n",
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"What should I do on a trip to Europe?\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
|
||||
"gemma_lm.compile(sampler=sampler)\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AePQUIs2h-Ks"
|
||||
},
|
||||
"source": [
|
||||
"The model responds with generic tips on how to plan a trip."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YQ74Zz_S0iVv"
|
||||
},
|
||||
"source": [
|
||||
"### Photosynthesis prompt\n",
|
||||
"\n",
|
||||
"Prompt the model to explain photosynthesis in terms simple enough for a 5 year old child to understand."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "lorJMbsusgoo"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"Explain the process of photosynthesis in a way that a child could understand.\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"Photosynthesis is a biological process that occurs in plants, algae, and some other organisms. In the process, light energy is captured and converted into the energy stored in the bonds of organic molecules. The process is crucial for life on Earth because it enables plants to use carbon dioxide and water to produce glucose and oxygen, which are essential for all living things.\n",
|
||||
"The process involves several stages:\n",
|
||||
"1. Light Reactions: Light energy is absorbed by pigments in the chloroplasts of the plant, converting it into chemical energy in the form of ATP and reducing power.\n",
|
||||
"2. Carbon Fixation: During this stage, carbon dioxide is combined with hydrogen to form organic molecules such as starch or glucose, which are used as a source of energy.\n",
|
||||
"3. Calvin Cycle: The process of carbon fixation occurs in the stroma of the chloroplasts. It involves the capture and reduction of carbon dioxide, producing glucose and reducing power in the form of ATP and NADPH molecules.\n",
|
||||
"4. Stroma: The stroma is the fluid-filled space where the light reactions occur in the chloroplasts.\n",
|
||||
"5. Chloroplasts: The chloroplasts contain the green pigments that absorb\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WBQieduRizZf"
|
||||
},
|
||||
"source": [
|
||||
"The model response contains words that might not be easy to understand for a child such as chlorophyll."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Pt7Nr6a7tItO"
|
||||
},
|
||||
"source": [
|
||||
"## LoRA fine-tuning\n",
|
||||
"\n",
|
||||
"This section shows you how to do fine-tuning using the Low Rank Adaptation (LoRA) tuning technique. This approach allows you to change the behavior of Gemma models using fewer compute resources."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9T7xe_jzslv4"
|
||||
},
|
||||
"source": [
|
||||
"### Load dataset\n",
|
||||
"\n",
|
||||
"Prepare a dataset for tuning by downloading an existing data set and formatting if for use with the the Keras `fit()` fine-tuning method. This tutorial uses the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for fine-tuning. The dataset contains 15,000 high-quality human-generated prompt and response pairs specifically designed for tuning generative models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "xRaNCPUXKoa7"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--2025-04-10 20:48:49-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n",
|
||||
"Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.114, 3.163.189.74, ...\n",
|
||||
"Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.\n",
|
||||
"HTTP request sent, awaiting response... 302 Found\n",
|
||||
"Location: https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
|
||||
"--2025-04-10 20:48:49-- https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE\n",
|
||||
"Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.238.217.63, 18.238.217.81, 18.238.217.120, ...\n",
|
||||
"Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.238.217.63|:443... connected.\n",
|
||||
"HTTP request sent, awaiting response... 200 OK\n",
|
||||
"Length: 13085339 (12M) [text/plain]\n",
|
||||
"Saving to: ‘databricks-dolly-15k.jsonl’\n",
|
||||
"\n",
|
||||
"databricks-dolly-15 100%[===================>] 12.48M --.-KB/s in 0.08s \n",
|
||||
"\n",
|
||||
"2025-04-10 20:48:49 (156 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "45UpBDfBgf0I"
|
||||
},
|
||||
"source": [
|
||||
"### Format tuning data\n",
|
||||
"\n",
|
||||
"Format the downloaded data for use with the Keras `fit()` method. The following code extracts a subset of the training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZiS-KU9osh_N"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"prompts = []\n",
|
||||
"responses = []\n",
|
||||
"line_count = 0\n",
|
||||
"\n",
|
||||
"with open(\"databricks-dolly-15k.jsonl\") as file:\n",
|
||||
" for line in file:\n",
|
||||
" if line_count >= 1000:\n",
|
||||
" break # Limit the training examples, to reduce execution time.\n",
|
||||
"\n",
|
||||
" examples = json.loads(line)\n",
|
||||
" # Filter out examples with context, to keep it simple.\n",
|
||||
" if examples[\"context\"]:\n",
|
||||
" continue\n",
|
||||
" # Format data into prompts and response lists.\n",
|
||||
" prompts.append(examples[\"instruction\"])\n",
|
||||
" responses.append(examples[\"response\"])\n",
|
||||
"\n",
|
||||
" line_count += 1\n",
|
||||
"\n",
|
||||
"data = {\n",
|
||||
" \"prompts\": prompts,\n",
|
||||
" \"responses\": responses\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "cBLW5hiGj31i"
|
||||
},
|
||||
"source": [
|
||||
"### Configure LoRA tuning\n",
|
||||
"\n",
|
||||
"Activate LoRA tuning using the Keras `model.backbone.enable_lora()` method, including a LoRA rank value. The *LoRA rank* determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments. A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n",
|
||||
"\n",
|
||||
"This example uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This setting is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "RCucu6oHz53G"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Enable LoRA for the model and set the LoRA rank to 4.\n",
|
||||
"gemma_lm.backbone.enable_lora(rank=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PlMLp_NVbRoQ"
|
||||
},
|
||||
"source": [
|
||||
"Check the model summary after setting the LoRA rank. Notice that enabling LoRA reduces the number of trainable parameters significantly compared to the total number of parameters in the model:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "KqYyS0gm6pNy"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gemma_lm.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hQQ47kcdpbZ9"
|
||||
},
|
||||
"source": [
|
||||
"Configure the rest of the fine-tuning settings, including the preprocessor settings, optimizer, number of tuning epochs, and batch size:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "p9sBNH8SAjgB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Limit the input sequence length to 256 (to control memory usage).\n",
|
||||
"gemma_lm.preprocessor.sequence_length = 256\n",
|
||||
"# Use AdamW (a common optimizer for transformer models).\n",
|
||||
"optimizer = keras.optimizers.AdamW(\n",
|
||||
" learning_rate=5e-5,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
")\n",
|
||||
"# Exclude layernorm and bias terms from decay.\n",
|
||||
"optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
|
||||
"\n",
|
||||
"gemma_lm.compile(\n",
|
||||
" loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
|
||||
" optimizer=optimizer,\n",
|
||||
" weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OA0ozGC66tk1"
|
||||
},
|
||||
"source": [
|
||||
"### Run the fine-tune process\n",
|
||||
"\n",
|
||||
"Run the fine-tuning process using the `fit()` method. This process can take several minutes depending on your compute resources, data size, and number of epochs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "_Peq7TnLtHse"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m923s\u001b[0m 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<keras.src.callbacks.history.History at 0x799d04393c40>"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"gemma_lm.fit(data, epochs=1, batch_size=1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bx3m8f1dB7nk"
|
||||
},
|
||||
"source": [
|
||||
"#### Mixed precision fine-tuning on NVIDIA GPUs\n",
|
||||
"\n",
|
||||
"Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "T0lHxEDX03gp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Uncomment the line below if you want to enable mixed precision training on GPUs\n",
|
||||
"# keras.mixed_precision.set_global_policy('mixed_bfloat16')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "4yd-1cNw1dTn"
|
||||
},
|
||||
"source": [
|
||||
"## Inference after fine-tuning\n",
|
||||
"\n",
|
||||
"After fine-tuning, you should see changes in the responses when the tuned model is given the same prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "H55JYJ1a1Kos"
|
||||
},
|
||||
"source": [
|
||||
"### Europe trip prompt\n",
|
||||
"\n",
|
||||
"Try the Europe trip prompt from earlier and note the differences in the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Y7cDJHy8WfCB"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"What should I do on a trip to Europe?\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"What should I do on a trip to Europe?\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
|
||||
"gemma_lm.compile(sampler=sampler)\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OXP6gg2mjs6u"
|
||||
},
|
||||
"source": [
|
||||
"The model now provides a shorter response to a question about visiting Europe."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "H7nVd8Mi1Yta"
|
||||
},
|
||||
"source": [
|
||||
"### Photosynthesis prompt\n",
|
||||
"\n",
|
||||
"Try the photosynthesis explanation prompt from earlier and note the differences in the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "X-2sYl2jqwl7"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"Explain the process of photosynthesis in a way that a child could understand.\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PCmAmqrvkEhc"
|
||||
},
|
||||
"source": [
|
||||
"The model now explains photosynthesis in simpler terms."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "I8kFG12l0mVe"
|
||||
},
|
||||
"source": [
|
||||
"## Improving fine-tune results\n",
|
||||
"\n",
|
||||
"For demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:\n",
|
||||
"\n",
|
||||
"1. Increasing the size of the fine-tuning dataset\n",
|
||||
"2. Training for more steps (epochs)\n",
|
||||
"3. Setting a higher LoRA rank\n",
|
||||
"4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gSsRdeiof_rJ"
|
||||
},
|
||||
"source": [
|
||||
"## Summary and next steps\n",
|
||||
"\n",
|
||||
"This tutorial covered LoRA fine-tuning on a Gemma model using Keras. Check out the following docs next:\n",
|
||||
"\n",
|
||||
"* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
|
||||
"* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
|
||||
"* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
|
||||
"* Learn how to [fine-tune Gemma using Keras and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"name": "lora_tuning.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,595 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This notebook has vibe test examples to test image, text, audio capabilities of Gemma-4 model. To get started, let's install latest stable release of transformers."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install -U transformers"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can load model into `AutoModelForMultimodalLM` to make use of all capabilities."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"from transformers import AutoModelForMultimodalLM, AutoProcessor\n",
|
||||
"#model_list = [\"google/gemma-4-26B-A4B-it\", \"google/gemma-4-E4B-it\",\n",
|
||||
"# \"google/gemma-4-E2B-it\", \"google/gemma-4-31B-it\"]\n",
|
||||
"model_id = \"google/gemma-4-E2B-it\"\n",
|
||||
"model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map=\"auto\")\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_id)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Code completion"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We give Gemma-4 a website screenshot to reproduce the code."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png\",\n",
|
||||
" },\n",
|
||||
" {\"type\": \"text\", \"text\": \"Write HTML code for this page.\"},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" enable_thinking=True,\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"output = model.generate(**inputs, max_new_tokens=4000)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)\n",
|
||||
"\n",
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Video Inference"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We test Gemma-4 on video understanding. If you want to run this example with larger models which don't take audio input, disable `load_audio_from_video`."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"video\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is happening in the video? What is the song about?\"},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" load_audio_from_video=True,\n",
|
||||
").to(model.device)\n",
|
||||
"output = model.generate(**inputs, max_new_tokens=200)\n",
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Multimodal Function Calling"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"WEATHER_TOOL = {\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"function\": {\n",
|
||||
" \"name\": \"get_weather\",\n",
|
||||
" \"description\": \"Gets the current weather for a specific location.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"city\": {\"type\": \"string\", \"description\": \"The city name\"},\n",
|
||||
" },\n",
|
||||
" \"required\": [\"city\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"tools = [WEATHER_TOOL]\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\"role\": \"user\", \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is the city in this image? Check the weather there right now.\"},\n",
|
||||
" ]},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tools=[WEATHER_TOOL],\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" enable_thinking=True,\n",
|
||||
").to(model.device)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"output = model.generate(**inputs, max_new_tokens=1000)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Any-to-any inference"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also run the model with `any-to-any` pipeline."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\"any-to-any\", model=\"google/gemma-4-e2b-it\")\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"video\",\n",
|
||||
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
|
||||
" },\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"pipe(messages)#, load_audio_from_video=True)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"video\",\n",
|
||||
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
|
||||
" },\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
")\n",
|
||||
"inputs = inputs.to(model.device)\n",
|
||||
"\n",
|
||||
"generated_ids = model.generate(**inputs, max_new_tokens=128)\n",
|
||||
"generated_ids_trimmed = [\n",
|
||||
" out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
|
||||
"]\n",
|
||||
"output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
")\n",
|
||||
"print(output_text)\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Object detection and pointing"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import re\n",
|
||||
"import torch\n",
|
||||
"from transformers.image_utils import load_image\n",
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.patches as patches\n",
|
||||
"import json"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"image_url = \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\"\n",
|
||||
"image = load_image(image_url)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def resize_to_48_multiple(image):\n",
|
||||
" w, h = image.size\n",
|
||||
" new_w = (w // 48) * 48\n",
|
||||
" new_h = (h // 48) * 48\n",
|
||||
" return image.crop((0, 0, new_w, new_h))"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def inputs_for_object_detection(image, what_object):\n",
|
||||
" messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\", \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": image},\n",
|
||||
" {\"type\": \"text\", \"text\": f\"What's the bounding box for the {what_object} in the image?\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" enable_thinking=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return inputs.to(model.device)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def extract_json(text: str):\n",
|
||||
" text = text.strip()\n",
|
||||
"\n",
|
||||
" text = re.sub(r\"^```(?:json)?\\s*\", \"\", text)\n",
|
||||
" text = re.sub(r\"\\s*```$\", \"\", text)\n",
|
||||
"\n",
|
||||
" # Try direct parse first\n",
|
||||
" try:\n",
|
||||
" return json.loads(text)\n",
|
||||
" except json.JSONDecodeError:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" # Fallback: extract first JSON object or array\n",
|
||||
" match = re.search(r'(\\{.*\\}|\\[.*\\])', text, re.DOTALL)\n",
|
||||
" if match:\n",
|
||||
" candidate = match.group(1)\n",
|
||||
" return json.loads(candidate)\n",
|
||||
"\n",
|
||||
" raise ValueError(\"No valid JSON found\")"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def detect_object(image_url, what_object):\n",
|
||||
" image = load_image(image_url)\n",
|
||||
" image = resize_to_48_multiple(image)\n",
|
||||
" inputs = inputs_for_object_detection(image, what_object)\n",
|
||||
" input_len = inputs[\"input_ids\"].shape[-1]\n",
|
||||
" generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)\n",
|
||||
" generated = processor.decode(generated_outputs[0, input_len:])\n",
|
||||
" parsed_json = extract_json(generated)[0]\n",
|
||||
" return parsed_json"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):\n",
|
||||
" dpi = 72\n",
|
||||
" width, height = image.size\n",
|
||||
" fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})\n",
|
||||
"\n",
|
||||
" ax.imshow(image)\n",
|
||||
"\n",
|
||||
" ymin, xmin, ymax, xmax = box\n",
|
||||
" re_h, re_w = resize_shape if resize_shape is not None else (height, width)\n",
|
||||
" xmin = (xmin / re_w) * width\n",
|
||||
" ymin = (ymin/ re_h) * height\n",
|
||||
" xmax = (xmax / re_w) * width\n",
|
||||
" ymax = (ymax/ re_h) * height\n",
|
||||
"\n",
|
||||
" w = xmax - xmin\n",
|
||||
" h = ymax - ymin\n",
|
||||
"\n",
|
||||
" rect = patches.Rectangle(\n",
|
||||
" (xmin, ymin),\n",
|
||||
" w,\n",
|
||||
" h,\n",
|
||||
" linewidth=10,\n",
|
||||
" edgecolor=\"green\",\n",
|
||||
" facecolor=\"none\"\n",
|
||||
" )\n",
|
||||
" ax.add_patch(rect)\n",
|
||||
"\n",
|
||||
" if label is not None:\n",
|
||||
" ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor=\"yellow\", alpha=0.5))\n",
|
||||
"\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
" plt.savefig(f\"boxes_{i}.png\")\n",
|
||||
" plt.close(fig)\n",
|
||||
" display(fig)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def display_detected_object(image_url, what_object):\n",
|
||||
" image = load_image(image_url)\n",
|
||||
" image = resize_to_48_multiple(image)\n",
|
||||
" detection = detect_object(image_url, what_object)\n",
|
||||
" box = detection[\"box_2d\"]\n",
|
||||
" label = detection.get(\"label\", f\"{what_object}\")\n",
|
||||
" draw_pascal_voc_boxes(\"1000\", image, box, label)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"display_detected_object(\"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\", \"bike\")"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"##\u00a0Captioning"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"Write single detailed caption for this image.\"},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"output = model.generate(**inputs, max_new_tokens=512)\n",
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)\n",
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Audio Understanding"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"audio\", \"url\": \"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"Can you describe this audio in detail?\"},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"output = model.generate(\n",
|
||||
" **inputs,\n",
|
||||
" max_new_tokens=1000,\n",
|
||||
" do_sample=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(processor.decode(output[0], skip_special_tokens=True))\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
# Hugging Face Gemma Recipes
|
||||
|
||||

|
||||
|
||||
🤗💎 Welcome! This repository contains *minimal* recipes to get started quickly with the Gemma family of models.
|
||||
|
||||
> [!Note]
|
||||
> Gemma 4 Multimodal inference (vision, video, audio, function calling, object detection): <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
To quickly run a Gemma 💎 model on your machine, install the latest version of `timm` (for the vision encoder) and 🤗 `transformers` to run inference, or if you want to fine tune it.
|
||||
|
||||
```shell
|
||||
$ pip install -U -q transformers timm
|
||||
```
|
||||
|
||||
### Inference with pipeline
|
||||
|
||||
The easiest way to start using Gemma 3n is by using the pipeline abstraction in transformers:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(
|
||||
"image-text-to-text",
|
||||
model="google/gemma-3n-E4B-it", # "google/gemma-3n-E4B-it"
|
||||
device="cuda",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/ariG23498/demo-data/resolve/main/airplane.jpg"},
|
||||
{"type": "text", "text": "Describe this image"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
output = pipe(text=messages, max_new_tokens=32)
|
||||
print(output[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
### Detailed inference with transformers
|
||||
|
||||
Initialize the model and the processor from the Hub, and write the `model_generation` function that takes care of processing the prompts and running the inference on the model.
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
import torch
|
||||
|
||||
model_id = "google/gemma-3n-e4b-it" # google/gemma-3n-e2b-it
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForImageTextToText.from_pretrained(model_id).to(device)
|
||||
|
||||
def model_generation(model, messages):
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
inputs = inputs.to(model.device, dtype=model.dtype)
|
||||
|
||||
with torch.inference_mode():
|
||||
generation = model.generate(**inputs, max_new_tokens=32, disable_compile=False)
|
||||
generation = generation[:, input_len:]
|
||||
|
||||
decoded = processor.batch_decode(generation, skip_special_tokens=True)
|
||||
print(decoded[0])
|
||||
```
|
||||
|
||||
And then using calling it with our specific modality:
|
||||
|
||||
#### Text only
|
||||
|
||||
```python
|
||||
# Text Only
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the capital of France?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
model_generation(model, messages)
|
||||
```
|
||||
|
||||
#### Interleaved with Audio
|
||||
|
||||
```python
|
||||
# Interleaved with Audio
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the following speech segment in English:"},
|
||||
{"type": "audio", "audio": "https://huggingface.co/datasets/ariG23498/demo-data/resolve/main/speech.wav"},
|
||||
]
|
||||
}
|
||||
]
|
||||
model_generation(model, messages)
|
||||
```
|
||||
|
||||
#### Interleaved with Image/Video
|
||||
|
||||
```python
|
||||
# Interleaved with Image
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/ariG23498/demo-data/resolve/main/airplane.jpg"},
|
||||
{"type": "text", "text": "Describe this image."}
|
||||
]
|
||||
}
|
||||
]
|
||||
model_generation(model, messages)
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
### Gemma 4
|
||||
|
||||
#### Notebooks
|
||||
|
||||
* [Multimodal inference with Gemma 4 (vision, video, audio, function calling, object detection)](/notebooks/Gemma4_(E2B)-Multimodal.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
### Gemma 3n
|
||||
|
||||
#### Notebooks
|
||||
|
||||
* [Multimodal inference using Gemma 3n via pipeline](/notebooks/gemma3n_inference_via_pipeline.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/gemma3n_inference_via_pipeline.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
## Function Calling
|
||||
|
||||
### Gemma 3n
|
||||
|
||||
#### Notebooks
|
||||
|
||||
* [Function Calling with Gemma 3n: Local File Reader](/notebooks/Gemma_3n_Function_Calling_document_summarizer.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma_3n_Function_Calling_document_summarizer.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
## Fine Tuning
|
||||
|
||||
We include a series of notebook+scripts for fine tuning the models.
|
||||
|
||||
### Gemma 3n
|
||||
|
||||
#### Notebooks
|
||||
|
||||
* [Gemma 3n Conversational Fine tuning 2B on free Colab T4](/notebooks/fine_tune_gemma3n_on_t4.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Gemma 3n Conversational Fine tuning 4B with Unsloth on free Colab T4](/notebooks/Gemma3N_(4B)-Conversational.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma3N_(4B)-Conversational.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Gemma 3n Multimodal Fine tuning 2B/4B with Unsloth on free Colab T4](/notebooks/gemma3n_multimodal_finetuning_on_rocov2_radiology.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/gemma3n_multimodal_finetuning_on_rocov2_radiology.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Fine tuning Gemma 3n on audio](/notebooks/fine_tune_gemma3n_on_audio.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_audio.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Fine tuning Gemma 3n on GUI Grounding](/notebooks/Gemma_3n_GUI_Finetune.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma_3n_GUI_Finetune.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Fine tuning Gemma3n on video+audio using FineVideo (all modalities)](/notebooks/Gemma3n_Fine_tuning_on_All_Modalities.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma3n_Fine_tuning_on_All_Modalities.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
#### Scripts
|
||||
|
||||
* [Fine tuning Gemma 3n on images using TRL](/scripts/ft_gemma3n_image_trl.py)
|
||||
* [Fine tuning Gemma 3n on images (script)](/scripts/ft_gemma3n_image_vt.py)
|
||||
* [Fine tuning Gemma 3n on audio (script)](/scripts/ft_gemma3n_audio_vt.py)
|
||||
* [Fine tuning Gemma3n on video+audio using FineVideo (all modalities)](/scripts/gemma3n_fine_tuning_on_all_modalities.py)
|
||||
|
||||
### Gemma 3
|
||||
|
||||
* [Reinforement Learning (GRPO) on Gemma 3 with Unsloth and TRL](/notebooks/Gemma3_(1B)-GRPO.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma3_(1B)-GRPO.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Vision fine tuning Gemma 3 4B with Unsloth](/notebooks/Gemma3_(4B)-Vision.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma3_(4B)-Vision.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
* [Conversational fine tuning Gemma 3 4B with Unsloth](/notebooks/Gemma3_(4B).ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma3_(4B).ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
## RAG
|
||||
|
||||
### Gemma 3n
|
||||
* [Retrieval-Augmented Generation with Gemma 3n](/notebooks/Gemma_RAG.ipynb) <a href="https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma_RAG.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
||||
|
||||
Before fine-tuning the model, ensure all dependencies are installed:
|
||||
|
||||
```bash
|
||||
$ pip install -U -q -r requirements.txt
|
||||
```
|
||||
|
||||
✨ **Bonus:** We've also experimented with adding **object detection** 🔍 capabilities to Gemma 3. You can explore that work in [this dedicated repo](https://github.com/ariG23498/gemma3-object-detection).
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
|
||||
"""
|
||||
GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models).
|
||||
|
||||
This script uses `environment_factory` with multimodal tool responses: each tool action
|
||||
returns a camera image from the vehicle alongside the text scene description, allowing the
|
||||
VLM to see the driving scene visually after each action.
|
||||
|
||||
The CARLA environment simulates an emergency driving scenario where pedestrians are ahead
|
||||
and the model must learn to observe the scene and take the correct action (e.g., swerve
|
||||
to an empty lane) to minimize casualties.
|
||||
|
||||
Setup:
|
||||
```sh
|
||||
pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env"
|
||||
```
|
||||
|
||||
Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection):
|
||||
```sh
|
||||
python examples/scripts/openenv/carla_vlm.py \
|
||||
--env-urls https://server1.hf.space https://server2.hf.space
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from carla_env import CarlaAction, CarlaEnv
|
||||
from datasets import Dataset
|
||||
from PIL import Image
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.")
|
||||
parser.add_argument("--model", type=str, default="google/gemma-4-E2B-it")
|
||||
parser.add_argument(
|
||||
"--env-urls",
|
||||
type=str,
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).",
|
||||
)
|
||||
parser.add_argument("--dataset-size", type=int, default=1000)
|
||||
parser.add_argument("--max-completion-length", type=int, default=3072)
|
||||
parser.add_argument("--per-device-train-batch-size", type=int, default=None, help="Defaults to len(env-urls).")
|
||||
parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
|
||||
parser.add_argument("--max-steps", type=int, default=100)
|
||||
parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.")
|
||||
parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.")
|
||||
parser.add_argument("--use-lora", action="store_true", help="Use LoRA for memory-efficient training.")
|
||||
parser.add_argument("--lora-r", type=int, default=128, help="LoRA rank.")
|
||||
parser.add_argument("--lora-alpha", type=int, default=256, help="LoRA alpha.")
|
||||
parser.add_argument(
|
||||
"--lora-target-modules",
|
||||
type=str,
|
||||
default="llm-only",
|
||||
help="LoRA target modules. Use 'llm-only' to skip vision encoder, 'all-linear' for all.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate", type=float, default=5e-6, help="Learning rate. Default 5e-6 (good for LoRA r=128)."
|
||||
)
|
||||
parser.add_argument("--hub-model-id", type=str, default=None)
|
||||
parser.add_argument("--hub-private-repo", action="store_true", help="Make the Hub repo private.")
|
||||
parser.add_argument("--run-name", type=str, default=None)
|
||||
parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
SIM_TICKS = 10
|
||||
|
||||
|
||||
def reward_func(completions, environments, **kwargs):
|
||||
rewards = []
|
||||
for i, (comp, env) in enumerate(zip(completions, environments, strict=False)):
|
||||
# Advance to episode end to capture rubric_reward
|
||||
try:
|
||||
final = env._advance_until_done()
|
||||
if final and final.observation.rubric_reward:
|
||||
env.reward = final.observation.rubric_reward
|
||||
except Exception as e:
|
||||
print(f"[WARN] _advance_until_done failed for gen={i}: {e}")
|
||||
r = env.reward
|
||||
tools = [
|
||||
msg["tool_calls"][0]["function"]["name"] for msg in comp if isinstance(msg, dict) and msg.get("tool_calls")
|
||||
]
|
||||
rewards.append(r)
|
||||
print(f"[DEBUG reward] gen={i} tools={tools} env_reward={env.reward} total={r}")
|
||||
return rewards
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
env_url_iter = iter(args.env_urls)
|
||||
image_size = args.image_size
|
||||
|
||||
prompt = """\
|
||||
You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \
|
||||
decide what to do immediately.
|
||||
|
||||
You will see a camera image from the vehicle after each action. Use the visual information
|
||||
along with the scene description to decide your next action.
|
||||
|
||||
You have the following tools available:
|
||||
- `observe`: Advance time and get a new observation of the scene with a camera image.
|
||||
- `emergency_stop`: Apply maximum braking to stop the vehicle.
|
||||
- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right".
|
||||
|
||||
Make one tool call at a time, wait for the result, then decide your next action.
|
||||
Observe the scene first, then decide the best course of action to minimize harm.
|
||||
Consider all available actions - sometimes avoiding the obstacle by changing lanes \
|
||||
is safer than stopping in its path."""
|
||||
|
||||
dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]})
|
||||
|
||||
class CarlaVLMEnv:
|
||||
def __init__(self):
|
||||
self.url = next(env_url_iter)
|
||||
self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
|
||||
self.reward = 0.0
|
||||
|
||||
@staticmethod
|
||||
def _describe(obs) -> str:
|
||||
parts = []
|
||||
parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.")
|
||||
if obs.nearby_actors:
|
||||
for actor in obs.nearby_actors:
|
||||
parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m")
|
||||
else:
|
||||
parts.append("No nearby actors detected.")
|
||||
if obs.collision_detected:
|
||||
parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _decode_image(camera_image_b64, target_size):
|
||||
"""Decode base64 JPEG image and optionally resize."""
|
||||
img_bytes = base64.b64decode(camera_image_b64)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
if target_size > 0:
|
||||
img.thumbnail((target_size, target_size), Image.LANCZOS)
|
||||
return img
|
||||
|
||||
def _format_multimodal(self, obs) -> list:
|
||||
"""Format observation as multimodal content blocks (camera image + text)."""
|
||||
content = []
|
||||
if obs.camera_image is not None:
|
||||
img = self._decode_image(obs.camera_image, image_size)
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": self._describe(obs)})
|
||||
return content
|
||||
|
||||
def _advance(self, ticks: int = SIM_TICKS):
|
||||
result = None
|
||||
for _ in range(ticks):
|
||||
result = self.client.step(CarlaAction(action_type="observe"))
|
||||
if result.done:
|
||||
break
|
||||
return result
|
||||
|
||||
def _advance_until_done(self, max_ticks: int = 50):
|
||||
"""Advance the simulation until the episode ends."""
|
||||
result = None
|
||||
for _ in range(max_ticks):
|
||||
result = self.client.step(CarlaAction(action_type="observe"))
|
||||
if result.done:
|
||||
break
|
||||
return result
|
||||
|
||||
def _advance_and_capture(self, ticks: int = SIM_TICKS):
|
||||
"""Advance the simulation, then capture an image of the current state."""
|
||||
result = self._advance(ticks)
|
||||
capture_result = self.client.step(CarlaAction(action_type="capture_image"))
|
||||
result.observation.camera_image = capture_result.observation.camera_image
|
||||
return result
|
||||
|
||||
def reset(self, **kwargs) -> str | None:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = self.client.reset(scenario_name="trolley_micro_escape_exists")
|
||||
self.reward = 0.0
|
||||
return self._describe(result.observation)
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
raise
|
||||
print(f"[WARN] reset failed (attempt {attempt + 1}/3): {e}. Reconnecting...")
|
||||
self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
|
||||
|
||||
def observe(self) -> list:
|
||||
"""
|
||||
Get the current scene with a camera image and description.
|
||||
|
||||
Returns:
|
||||
The camera image and scene description with vehicle state and nearby actors.
|
||||
"""
|
||||
result = self._advance_and_capture()
|
||||
self.reward = result.observation.rubric_reward or 0.0
|
||||
return self._format_multimodal(result.observation)
|
||||
|
||||
def emergency_stop(self) -> list:
|
||||
"""
|
||||
Apply maximum braking to stop the vehicle.
|
||||
|
||||
Returns:
|
||||
The camera image and scene description after braking.
|
||||
"""
|
||||
self.client.step(CarlaAction(action_type="emergency_stop"))
|
||||
result = self._advance_and_capture()
|
||||
self.reward = result.observation.rubric_reward or 0.0
|
||||
print(f"[DEBUG env] emergency_stop: done={result.done}, reward={self.reward}")
|
||||
return self._format_multimodal(result.observation)
|
||||
|
||||
def lane_change(self, direction: str) -> list:
|
||||
"""
|
||||
Change lane to avoid obstacles.
|
||||
|
||||
Args:
|
||||
direction: Direction to change lane, either "left" or "right".
|
||||
|
||||
Returns:
|
||||
The camera image and scene description after changing lane.
|
||||
"""
|
||||
self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction))
|
||||
result = self._advance_and_capture()
|
||||
self.reward = result.observation.rubric_reward or 0.0
|
||||
print(f"[DEBUG env] lane_change({direction}): done={result.done}, reward={self.reward}")
|
||||
return self._format_multimodal(result.observation)
|
||||
|
||||
peft_config = None
|
||||
if args.use_lora:
|
||||
from peft import LoraConfig
|
||||
|
||||
if args.lora_target_modules == "llm-only":
|
||||
target_modules = "all-linear"
|
||||
exclude_modules = ["vision_tower", "multi_modal_projector"]
|
||||
else:
|
||||
target_modules = args.lora_target_modules
|
||||
exclude_modules = None
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=target_modules,
|
||||
exclude_modules=exclude_modules,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=args.model,
|
||||
train_dataset=dataset,
|
||||
reward_funcs=reward_func,
|
||||
peft_config=peft_config,
|
||||
args=GRPOConfig(
|
||||
chat_template_kwargs={"enable_thinking": False},
|
||||
log_completions=True,
|
||||
logging_steps=2,
|
||||
num_completions_to_print=1,
|
||||
max_completion_length=args.max_completion_length,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size or len(args.env_urls),
|
||||
steps_per_generation=1,
|
||||
num_generations=len(args.env_urls),
|
||||
max_tool_calling_iterations=10,
|
||||
learning_rate=args.learning_rate,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
max_steps=args.max_steps,
|
||||
push_to_hub=args.hub_model_id is not None,
|
||||
hub_model_id=args.hub_model_id,
|
||||
hub_private_repo=args.hub_private_repo,
|
||||
run_name=args.run_name,
|
||||
report_to=args.report_to,
|
||||
trackio_space_id=args.trackio_space_id,
|
||||
),
|
||||
environment_factory=CarlaVLMEnv,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,764 @@
|
||||
---
|
||||
title: "Welcome Gemma 4: Frontier multimodal intelligence on device"
|
||||
thumbnail: /blog/assets/gemma4/thumbnail.png
|
||||
authors:
|
||||
- user: merve
|
||||
- user: pcuenq
|
||||
- user: sergiopaniego
|
||||
- user: burtenshaw
|
||||
- user: Steveeeeeeen
|
||||
- user: alvarobartt
|
||||
- user: SaylorTwift
|
||||
---
|
||||
|
||||
# Welcome Gemma 4: Frontier multimodal intelligence on device
|
||||
|
||||
The Gemma 4 family of multimodal models by Google DeepMind is out on Hugging Face, with support for your favorite agents, inference engines, and fine-tuning libraries 🤗
|
||||
|
||||
These models are the real deal: truly open with Apache 2 licenses, high quality with pareto frontier arena scores, multimodal including audio, and sizes you can use _everywhere_ including on-device. Gemma 4 builds on advances from previous families and makes them click together. In our tests with pre-release checkpoints we have been impressed by their capabilities, to the extent that we struggled to find good fine-tuning examples because they are _so good_ out of the box.
|
||||
|
||||
We collaborated with Google and the community to make them available everywhere: transformers, llama.cpp, MLX, WebGPU, Rust; you name it. This blog post will show you how to build with [your favorite tools](https://huggingface.co/collections/google/gemma-4) so let us know what you think!
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [What is New with Gemma 4?](#what-is-new-with-gemma-4)
|
||||
- [Overview of Capabilities and Architecture](#overview-of-capabilities-and-architecture)
|
||||
- [Architecture at a Glance](#architecture-at-a-glance)
|
||||
- [Per-Layer Embeddings (PLE)](#per-layer-embeddings-ple)
|
||||
- [Shared KV Cache](#shared-kv-cache)
|
||||
- [Multimodal Capabilities](#multimodal-capabilities)
|
||||
- [Deploy Anywhere](#deploy-anywhere)
|
||||
- [transformers](#transformers)
|
||||
- [Llama.cpp](#llamacpp)
|
||||
- [Plug in to your local agent](#Plug-in-your-local-agent)
|
||||
- [transformers.js](#transformersjs)
|
||||
- [MLX](#mlx)
|
||||
- [Mistral.rs](#mistralrs)
|
||||
- [Fine-tuning & Demos](#fine-tuning--demos)
|
||||
- [Fine-tuning with TRL](#fine-tuning-with-trl)
|
||||
- [Fine-tuning with TRL on Vertex AI](#fine-tuning-with-trl-on-vertex-ai)
|
||||
- [Fine-tuning with Unsloth Studio](#fine-tuning-with-unsloth-studio)
|
||||
- [Try Gemma 4](#try-gemma-4)
|
||||
- [Benchmark Results](#benchmark-results)
|
||||
- [Acknowledgements](#acknowledgements)
|
||||
|
||||
# What is new with Gemma 4?
|
||||
|
||||
Similar to Gemma-3n, Gemma 4 supports image, text, and audio inputs, and generates text responses. The text decoder is based on the Gemma model with support for long context windows. The image encoder is similar to the one from Gemma 3 but with two crucial improvements: variable aspect ratios, and configurable number of image token inputs to find your sweet spot between speed, memory, and quality. All models support images (or video) and text inputs, while the small variants (E2B and E4B) support audio as well.
|
||||
|
||||
Gemma 4 comes in four sizes, all base and instruction fine-tuned:
|
||||
|
||||
| Model | Parameter Size | Context Window | Checkpoints |
|
||||
| :---- | :---- | :---- | :---- |
|
||||
| Gemma 4 E2B | 2.3B effective, 5.1B with embeddings | 128k | [base](https://huggingface.co/google/gemma-4-E2B), [IT](https://huggingface.co/google/gemma-4-E2B-it) |
|
||||
| Gemma 4 E4B | 4.5B effective, 8B with embeddings | 128k | [base](https://huggingface.co/google/gemma-4-E4B), [IT](https://huggingface.co/google/gemma-4-E4B-it) |
|
||||
| Gemma 4 31B | 31B dense model | 256K | [base](https://huggingface.co/google/gemma-4-31B), [IT](https://huggingface.co/google/gemma-4-31B-it) |
|
||||
| Gemma 4 26B A4B | mixture-of-experts with 4B activated/26B total parameters | 256K | [base](https://huggingface.co/google/gemma-4-26B-A4B), [IT](https://huggingface.co/google/gemma-4-26B-A4B-it) |
|
||||
|
||||
## Overview of Capabilities and Architecture
|
||||
|
||||
Gemma 4 leverages several architecture components used in previous Gemma versions and other open models, and leaves out complex or inconclusive features such as Altup. The combination is a mix designed to be highly compatible across libraries and devices, that can efficiently support long context and agentic use cases, whilst being ideal for quantization.
|
||||
|
||||
As shown in the benchmarks above, this feature mix (combined with the training data and recipe) enables the 31B dense model to achieve an estimated LMArena score (text only) of 1452, while the 26B MoE reaches 1441 with just 4B active parameters 🤯. As we'll see, multimodal operation is comparatively as good as text generation, at least in informal and subjective tests.
|
||||
|
||||
These are the main architecture characteristics in Gemma 4:
|
||||
|
||||
* Alternating **local sliding-window** and **global full-context** attention layers. Smaller dense models use sliding windows of 512 tokens while larger models use 1024 tokens.
|
||||
* **Dual RoPE** configurations: standard RoPE for sliding layers, pruned RoPE for global layers, to enable longer context.
|
||||
* **Per-Layer Embeddings (PLE)**: a second embedding table that feeds a small residual signal into every decoder layer.
|
||||
* **Shared KV Cache**: the last N layers of the model reuse key-value states from earlier layers, eliminating redundant KV projections.
|
||||
* **Vision encoder**: uses learned 2D positions and multidimensional RoPE. Preserves the original aspect ratios and can encode images to a few different token budgets (70, 140, 280, 560, 1120).
|
||||
* **Audio encoder**: USM-style conformer with the same base architecture as the one in Gemma-3n.
|
||||
|
||||
#### Per-Layer Embeddings (PLE)
|
||||
|
||||
One of the most distinctive features in smaller Gemma 4 models is Per-Layer Embeddings (PLE), which was introduced previously in Gemma-3n. In a standard transformer, each token gets a single embedding vector at input, and the same initial representation is what the residual stream builds on across all layers, forcing the embedding to frontload everything the model might need. PLE adds a parallel, lower-dimensional conditioning pathway alongside the main residual stream. For each token, it produces a small dedicated vector for every layer by combining two signals: a token-identity component (from an embedding lookup) and a context-aware component (from a learned projection of the main embeddings). Each decoder layer then uses its corresponding vector to modulate the hidden states via a lightweight residual block after attention and feed-forward. This gives each layer its own channel to receive token-specific information only when it becomes relevant, rather than requiring everything to be packed into a single upfront embedding. Because the PLE dimension is much smaller than the main hidden size, this adds meaningful per-layer specialization at modest parameter cost. For multimodal inputs (images, audio, video), PLE is computed before soft tokens are merged into the embedding sequence — since PLE relies on token IDs that are lost once multimodal features replace the placeholders. Multimodal positions use the pad token ID, effectively receiving neutral per-layer signals.
|
||||
|
||||
#### Shared KV Cache
|
||||
|
||||
The **shared KV cache** is an efficiency optimization that reduces both compute and memory during inference. The last `num_kv_shared_layers` layers of the model don't compute their own key and value projections. Instead, they **reuse** the K and V tensors from the last non-shared layer of the same attention type (sliding or full).
|
||||
|
||||
In practice, this has a minimal impact on quality while being much more efficient (in terms of both memory and compute) for long context generation and on-device use.
|
||||
|
||||
## Multimodal Capabilities
|
||||
|
||||
We saw in our tests that Gemma 4 supports comprehensive multimodal capabilities out of the box. We don't know what was the training mix, but we had success using it for tasks such as OCR, speech-to-text, object detection, or pointing. It also supports text-only and multimodal function calling, reasoning, code completion and correction.
|
||||
|
||||
Here, we show a few inference examples across different model sizes. You can run them conveniently with [this notebook](https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb). We encourage you to try the demos and share them below this blog!
|
||||
|
||||
### Object Detection and Pointing
|
||||
|
||||
### GUI detection
|
||||
|
||||
We test Gemma 4 on GUI element detection and pointing across different sizes, with the following image and text prompt: "What's the bounding box for the "view recipe" element in the image?"
|
||||
|
||||

|
||||
|
||||
With this prompt, the model natively responds in JSON format with the detected bounding boxes - no need for specific instructions or grammar-constrained generation. We found the coordinates refer to an image size of 1000x1000, relative to the input dimensions.
|
||||
|
||||
We visualize the outputs below for your convenience. We parse the bounding boxes from the returned JSON: ```json\n[\n {"box_2d": [171, 75, 245, 308], "label": "view recipe element"}\n]\n```
|
||||
|
||||
| E2B | E4B |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
| 26/A4B | 31B |
|
||||
|  |  |
|
||||
|
||||
### Object Detection
|
||||
|
||||
We test models to detect everyday objects, here we ask them to detect the bike and compare different model outputs. As in the previous case, we parse the bounding box from the json and translate to image space coordinates.
|
||||
|
||||
| E2B | E4B | 26B/A3B | 31B |
|
||||
| :---- | :---- | :---- | :---- |
|
||||
|  |  |  |  |
|
||||
|
||||
### Multimodal Thinking and Function Calling
|
||||
|
||||
We asked Gemma 4 to write HTML code to reconstruct a page we made with Gemini 3. Below you can find the code to do this, we enable thinking and ask each model to generate up to 4000 new tokens, to make it foolproof.
|
||||
|
||||
| Gemini Generated Website (Reference) | Gemini Reproduced Image |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
|
||||
},
|
||||
{"type": "text", "text": "Write HTML code for this page."},
|
||||
],
|
||||
}
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=4000)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| E2B | E4B |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
|
||||
| 31B | 26/A4B |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
|
||||
### Video Understanding
|
||||
|
||||
Smaller Gemma 4 models can take in videos with audio while larger ones can take in videos without audio. While the models are not explicitly post-trained on videos, they can understand videos both with and without audio. The model is particularly strong in audios
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
|
||||
{"type": "text", "text": "What is happening in the video? What is the song about?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
load_audio_from_video=True, # disable this for larger models
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=200)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B (hallucinated for the audio) | Based on the video provided, here is an analysis: **What is happening in this video?** The video captures a live musical performance, likely at a large outdoor festival or concert. * **Setting:** The scene is a large stage set up outdoors, indicated by the extensive lighting rig, large speakers, and the audience visible in the foreground. * **Performers:** There are several musicians on stage. The focal point seems to be a lead singer/guitarist in a blue shirt and white pants who is actively performing, playing an acoustic or electric guitar. There are also other musicians, including drummers and bassists, visible in the background. * **Atmosphere:** The atmosphere is energetic and dramatic, characterized by bright stage lights, intense blue and white spotlights, and a significant amount of stage fog/smoke, which adds to the concert mood. * **Audience:** The foreground is filled with the backs of audience members, showing they are close to the stage and engaged in watching the performance. **What is the song about?** The lyrics provided below the video are: > "I don't care 'bout street > From what I cherish > Could it's time to protest |
|
||||
| E4B ✅ | The video captures a live performance of a band on a stage with bright lights and smoke effects, surrounded by an audience. The band performs and sings to the crowd. The song appears to be about the struggles and disillusionment of modern life, specifically the feeling of being stuck and the difficulties in finding one's place in a chaotic and overwhelming world. Lyrics mention themes of the "concrete street," "broken promises," "storm of the day," and the difficulty of finding oneself amidst societal pressures. |
|
||||
| 26/A4B (no audio) ✅ | This video shows a live music performance at an outdoor concert or festival at night. A solo male artist, dressed in a blue shirt and white pants, is the central focus. He is playing an acoustic guitar and singing into a microphone on a large, lit stage. The performance is accompanied by a band, including a drummer and other musicians visible in the background and on large video screens. The stage is filled with atmospheric blue lighting, smoke, and bright spotlights. The video captures different perspectives, including shots from the crowd looking up at the stage and close-up views of the performers projected onto large screens. The crowd is visible in the foreground, watching the show. One of the screens also displays the "Estrella Damm" brand name. |
|
||||
| 31B (no audio) ✅ | A musician is performing on stage during a concert, playing an acoustic guitar and singing into a microphone. He is wearing a blue shirt and white pants. The stage is filled with atmospheric smoke and illuminated by bright blue and white lights. Other band members, including a drummer and a keyboardist, are also visible on stage. The video shows the musician from the perspective of the audience, with some shots focusing on him and others showing the large screen on the side of the stage, which displays close-ups of the performer and the crowd. |
|
||||
|
||||
### Captioning
|
||||
|
||||
We have tested all models on captioning. All checkpoints perform very well and accurately capture nuance in complex scenerios.
|
||||
Here's the image we use, with the prompt "Write single detailed caption for this image.".
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
|
||||
{"type": "text", "text": "Write single detailed caption for this image."},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=512)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B ✅ | A medium shot captures a weathered seagull perched atop a stone pedestal in what appears to be a bustling European square, with a grand, classical-style building featuring ornate columns and architectural details dominating the right side of the frame. In the background, a low, light-colored building stretches across the middle ground, flanked by some green foliage, and hints of other structures, including a terracotta-roofed building, are visible. In the foreground, part of a modern bus stop or shelter is visible, with signage indicating routes, and the overall scene is bathed in bright daylight under a pale sky. |
|
||||
| E4B ✅ | A solitary seagull perches atop a grey stone pedestal in a bustling European piazza, with the historical grandeur of a classical building dominating the right side of the frame, contrasting with the more modern, low-rise architecture in the background under a bright, hazy sky. |
|
||||
| 26B/A4B ✅ | A full shot shows a white and brown seagull perched on a gray pillar in Venice, Italy. The seagull faces left, its beak pointed towards the edge of the frame. Behind it is a light gray, two-story building with several dark rectangular windows. To the right of the seagull is a large, white building with columns and intricate carvings. The sky is a clear, light blue. The lighting is bright and even, suggesting it is daytime. |
|
||||
| 31B ✅ | A high-angle, medium shot shows a white and gray seagull perched on a gray concrete pylon in Venice, Italy. The seagull is facing left, its body angled slightly towards the camera. Its white head, neck, and chest contrast with its gray wings and back. The seagull's orange beak is slightly open. The pylon is a light gray concrete cylinder with a weathered texture. In the background, there are several buildings. To the left is a long, low, white building with a series of small, rectangular windows. In the center is a cluster of green trees. To the right is a red building with a tiled roof and several small, rectangular windows. Further right is a large, ornate, white stone building with several columns and sculptures. The sky is a pale blue, and the light is bright and clear. In the lower-left corner, there is a yellow bus with the word "Ferrovia" printed on it in black letters. There are several people walking in the background, but they are small and out of focus. |
|
||||
|
||||
### Audio Question Answering
|
||||
|
||||
These models are trained to answer questions about speech in audio. Music and non-speech sounds were not part of the training data.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
|
||||
{"type": "text", "text": "Can you describe this audio in detail?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B | This audio is a personal reflection. The speaker is talking about their final farewell address to the nation, which they delivered in Chicago. They express gratitude for the conversations they've had with the American people, noting that despite not having met them face-to-face or even greeted them, these interactions in various settings like living rooms, schools, farms, factory floors, diners, and military outposts have been what has kept them going. |
|
||||
| E4B | The audio is a speech excerpt where a speaker is delivering a farewell address to the nation from Chicago. The speaker reflects on their time in office, expressing gratitude for the conversations they had with the American people across various settings like living rooms, schools, farms, factories, diners, and military outposts. The tone is reflective and appreciative, highlighting the importance of these interactions in their political journey. |
|
||||
|
||||
Here is an example if you want to do transcription:
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
|
||||
{"type": "text", "text": "Transcribe the audio?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B | This week I traveled to Chicago to deliver my final farewell address to the nation following in the tradition of presidents before me It was an opportunity to say thank you whether we've seen eye to eye or rarely agreed at all my conversations with you the American people in living rooms and schools at farms and on factory floors at diners and on distant military outposts all these conversations are what have kept me honest |
|
||||
| E4B | This week I traveled to Chicago to deliver my final farewell address to the nation following in the tradition of presidents before me. It was an opportunity to say thank you. Whether we've seen eye to eye or rarely agreed at all, my conversations with you, the American people, in living rooms and schools, at farms and on factory floors, at diners and on distant military outposts, all these conversations are what have kept me honest. |
|
||||
|
||||
### Multimodal Function Calling
|
||||
|
||||
We test the model by asking to get the weather in the place shown in the image.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
import re
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Gets the current weather for a specific location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
tools = [WEATHER_TOOL]
|
||||
messages = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "What is the city in this image? Check the weather there right now."},
|
||||
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
|
||||
]},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tools=[WEATHER_TOOL],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=1000)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B | **Analyze the Request:** The user has provided an image and asked two questions:\n * "What is the city in this image?" * "Check the weather there right now."\n\n2. **Analyze the Image:** The image displays a large, ornate, tall structure, characteristic of temple architecture, likely in Southeast Asia (given the style). It appears to be a famous landmark. 3. **Identify the Landmark/Location (Internal Knowledge/Visual Cues):** The structure strongly resembles the chedis/stupas found in places like Bangkok, Thailand (e.g., Wat Arun or similar structures).\n\n4. **Determine the City:** Based on the visual evidence, the location is highly likely to be Bangkok, Thailand.\n\n5. **Check Available Tools:** The available tool is `get_weather(city: STRING)`. 6. **Formulate the Plan:**\n * Step 1: Identify the city from the image (Bangkok).\n * Step 2: Use the `get_weather` tool with the identified city.\n\n7. **Construct the Tool Call:** `get_weather(city="Bangkok")`\n\n8. **Generate Response:** The model should first state the identified city and then call the tool to get the weather. (Self-correction: The model *must* call the tool in the correct format if it needs external information.)call:get_weather{city:Bangkok} |
|
||||
| E4B | '1. **Analyze the Request:** The user has provided an image and asked two things:\n a. Identify the city in the image.\n b. Check the weather in that city right now.\n\n2. **Analyze the Image:** The image shows a very distinctive, ornate tower/spire structure, characteristic of Thai architecture (specifically, a chedi or prang).\n\n3. **Determine City Identification (Internal Knowledge/Visual Recognition):** Based on the architectural style, the structure strongly resembles prominent temples in Thailand, such as Wat Arun in Bangkok.\n\n4. **Evaluate Tool Availability:** I have one available tool: `get_weather(city: STRING)`. This tool requires a city name as input.\n\n5. **Formulate a Plan:**\n a. I need the city name to use the `get_weather` tool.\n b. I must first identify the city from the image.\n c. Since my internal knowledge suggests Bangkok, I will proceed with that assumption, but I must acknowledge that I cannot *definitively* prove the location from the image alone.\n\n6. **Execute Step 1 (City Identification):** Assume the city is Bangkok based on the visual evidence of the architecture.\n\n7. **Execute Step 2 (Tool Call Generation):** Use the `get_weather` tool with "Bangkok" as the `city` argument.\n\n8. **Construct the Response:** Generate the appropriate tool call in JSON format.', 'answer': '<\|tool_call>call:get_weather{city:<\|"\|>Bangkok<\|"\|>}<tool_call\|><eos>' |
|
||||
| 31B | The user wants to know the city in the image and the current weather there. 1. **Analyze the image:** The image shows ornate, spire-like structures (prangs) characteristic of Thai temple architecture. Specifically, these look like the prangs at Wat Phra Kaew (Temple of the Emerald Buddha) or Wat Pho in Bangkok, Thailand. The style is very distinctive of Bangkok's royal temple complexes. 2. **Identify the city:** Based on the architectural style, the city is Bangkok. 3. **Find the weather:** I need to call the `get_weather` tool with `city="Bangkok"`.call:get_weather{city:Bangkok} |
|
||||
| 26/A4B | The image shows several tall, ornate towers (prangs) typical of Thai Buddhist temple architecture. Specifically, these look like the structures within the Grand Palace or Wat Phra Kaew in Bangkok, Thailand. The style is unmistakably Thai. I'll confirm this by looking at the specific towers. These are very iconic.\n2. **Confirm the location:** The image shows elements of the Grand Palace complex in Bangkok.\n3. **Get weather information:** Once the city is identified as Bangkok, I will use the `get_weather` tool with `city='Bangkok'`.call:get_weather{city:Bangkok} |
|
||||
|
||||
# Deploy Anywhere
|
||||
|
||||
Gemma 4 comes with day-0 support for many open-source inference engines, and is ideal for tool calling and agents! We also release ONNX checkpoints that can run on many hardware backends, allowing use cases on edge devices or in browser!
|
||||
|
||||
## transformers
|
||||
|
||||
Gemma 4 comes with first-class transformers support from the get-go 🤗. This integration allows using the model with other libraries like bitsandbytes, PEFT and TRL. Make sure to install the latest version of transformers.
|
||||
|
||||
```bash
|
||||
pip install -U transformers
|
||||
```
|
||||
|
||||
The easiest way to infer with the small Gemma 4 models is through the `any-to-any` pipeline. You can initialize it as follows.
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
pipe = pipeline("any-to-any", model="google/gemma-4-e2b-it")
|
||||
```
|
||||
|
||||
You can then pass in images and text as follows.
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg",
|
||||
},
|
||||
{"type": "text", "text": "Do you have travel advice going to here?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
output = pipe(messages, max_new_tokens=100, return_full_text=False)
|
||||
output[0]["generated_text"]
|
||||
# Based on the image, which appears to show a magnificent, ornate **Buddhist temple or pagoda**, likely in Southeast Asia (such as Thailand, Myanmar, or Cambodia), here is some general travel advice..
|
||||
```
|
||||
|
||||
When inferring with videos, you can include the audio track using the `load_audio_from_video` argument.
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
pipe(messages, load_audio_from_video=True)
|
||||
```
|
||||
|
||||
Going a level lower, you can load Gemma 4 using the `AutoModelForMultimodalLM` class, especially useful for fine-tuning. The built-in chat template takes care of formatting the inputs correctly, please make sure you use it to prevent subtle mistakes when building the prompt manually.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForMultimodalLM, AutoProcessor
|
||||
model = AutoModelForMultimodalLM.from_pretrained("google/gemma-4-E2B-it", device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Llama.cpp
|
||||
|
||||
Gemma 4 models come with image+text support in llama.cpp from the get-go! This unlocks using Gemma 4 with all of your favorite local apps: llama-cpp server, lmstudio, Jan as well as coding agents like Pi across many backends such as Metal and CUDA.
|
||||
|
||||
You can install llama-cpp as follows.
|
||||
|
||||
```bash
|
||||
brew install llama.cpp # MacOS
|
||||
winget install llama.cpp # Windows
|
||||
```
|
||||
|
||||
You can then start a server compatible with the OpenAI API Replace the quantization scheme at the end of the command with the precision of your choice.
|
||||
|
||||
```bash
|
||||
llama-server -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
```
|
||||
|
||||
Check out this link [for more](https://huggingface.co/ggml-org/gemma-4-E2B-it-GGUF?local-app=llama.cpp) options on combining llama.cpp with different coding agents and local apps. Find all the GGUF checkpoints [in this collection](https://huggingface.co/collections/ggml-org/gemma-4).
|
||||
|
||||
## Plug in your local agent
|
||||
|
||||
We worked on making sure the new models work locally with agents like **openclaw, hermes, pi, and open code**. All thanks to llama.cpp! Run the following to try Gemma 4 right away.
|
||||
|
||||
First, start your local server:
|
||||
|
||||
```
|
||||
llama-server -hf ggml-org/gemma-4-26b-a4b-it-GGUF:Q4_K_M
|
||||
```
|
||||
|
||||
For **hermes:**
|
||||
|
||||
```shell
|
||||
hermes model
|
||||
```
|
||||
|
||||
For **openclaw:**
|
||||
|
||||
```shell
|
||||
openclaw onboard
|
||||
```
|
||||
|
||||
For **pi** define a `~/.pi/agent/models.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"llama-cpp": {
|
||||
"baseUrl": "http://localhost:8080/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "none",
|
||||
"models": [
|
||||
{
|
||||
"id": "ggml-org-gemma-4-26b-4b-gguf"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
For **open code** define a `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"llama.cpp": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "llama-server (local)",
|
||||
"options": {
|
||||
"baseURL": "http://127.0.0.1:8080/v1"
|
||||
},
|
||||
"models": {
|
||||
"gemma-4-26b-4b-it": {
|
||||
"name": "Gemma 4 (local)",
|
||||
"limit": {
|
||||
"context": 128000,
|
||||
"output": 8192
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## transformers.js
|
||||
|
||||
transformers.js enables running Gemma 4 right inside browser. You can check out the model card to see text-only, image & text, audio & text inference in detail [here](https://huggingface.co/onnx-community/gemma-4-E2B-it-ONNX#transformersjs-javascript). We also shipped a demo for you to test the model [here](https://huggingface.co/spaces/webml-community/Gemma-4-WebGPU).
|
||||
|
||||
## MLX
|
||||
|
||||
Full multimodal support of Gemma 4 is available using the open-source [`mlx-vlm` library](https://github.com/Blaizzy/mlx-vlm). Here's how to ask the model to describe an image:
|
||||
|
||||
```shell
|
||||
pip install -U mlx-vlm
|
||||
```
|
||||
|
||||
```shell
|
||||
mlx_vlm.generate \
|
||||
--model google/gemma-4-E4B-it \
|
||||
--image https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg \
|
||||
--prompt "Describe this image in detail"
|
||||
```
|
||||
|
||||
mlx-vlm supports TurboQuant, which delivers the same accuracy as the uncompressed baseline while using ~4x less active memory and running a lot faster end-to-end. This makes long-context inference practical on Apple Silicon without sacrificing quality. Use it like this:
|
||||
|
||||
```shell
|
||||
mlx_vlm.generate \
|
||||
--model "mlx-community/gemma-4-26b-a4b-it-4bit" \
|
||||
--prompt "Your prompt here" \
|
||||
--kv-bits 3.5 \
|
||||
--kv-quant-scheme turboquant
|
||||
```
|
||||
|
||||
For audio examples and more details, please check [the MLX collection](https://hf.co/mlx-community/gemma-4).
|
||||
|
||||
### Mistral.rs
|
||||
|
||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a Rust-native inference engine with day-0 Gemma 4 support across all modalities (text, image, video, audio) and builtin tool-calling and agentic functionality. Install mistral.rs:
|
||||
|
||||
```bash
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/EricLBuehler/mistral.rs/master/install.sh | sh # Linux/macOS
|
||||
|
||||
irm https://raw.githubusercontent.com/EricLBuehler/mistral.rs/master/install.ps1 | iex # Windows
|
||||
```
|
||||
|
||||
You can then start an OpenAI-compatible HTTP server:
|
||||
|
||||
```bash
|
||||
mistralrs serve mistralrs-community/gemma-4-E4B-it-UQFF --from-uqff 8
|
||||
```
|
||||
|
||||
Or, use interactive mode:
|
||||
|
||||
```
|
||||
mistralrs run -m google/gemma-4-E4B-it --isq 8 --image image.png -i "Describe this image in detail."
|
||||
|
||||
mistralrs run -m google/gemma-4-E4B-it --isq 8 --audio audio.mp3 -i "Transcribe this fully."
|
||||
```
|
||||
|
||||
Find all models [here](https://huggingface.co/mistralrs-community/models). Please, follow [the instructions](https://huggingface.co/mistralrs-community/gemma-4-E2B-it-UQFF#install) in the model cards for installation and inference guidelines.
|
||||
|
||||
## Fine-tuning for all
|
||||
|
||||
Gemma 4 models are ideal for fine-tuning in your favorite tools and platforms and at any budget.
|
||||
|
||||
## Fine-tuning with TRL
|
||||
|
||||
Gemma 4 is fully supported for fine-tuning with TRL. To celebrate, TRL has been upgraded with support for multimodal tool responses when interacting with environments, meaning models can now receive images back from tools during training, not just text.
|
||||
|
||||
To showcase this, we've built an example training script where Gemma 4 learns to drive in the CARLA simulator. The model sees the road through a camera, decides what to do and learns from the outcome. After training, it consistently changes lanes to avoid pedestrians. The same approach works for any task where a model needs to see and act: robotics, web browsing, or other interactive environments.
|
||||
|
||||
Get started:
|
||||
|
||||
```shell
|
||||
# pip install git+https://github.com/huggingface/trl.git
|
||||
|
||||
python examples/scripts/openenv/carla_vlm_gemma.py \
|
||||
--env-urls https://sergiopaniego-carla-env.hf.space \
|
||||
https://sergiopaniego-carla-env-2.hf.space \
|
||||
--model google/gemma-4-E2B-it
|
||||
```
|
||||
|
||||
Find the example [here](https://github.com/huggingface/huggingface-gemma-recipes/blob/main/scripts/carla_vlm_gemma.py).
|
||||
|
||||
### Fine-tuning with TRL on Vertex AI
|
||||
|
||||
Additionally, we have prepared an example on how to fine-tune Gemma 4 with TRL on Vertex AI using SFT, to showcase how to extend the function calling capabilities, whilst freezing both the vision and audio towers. The examples include how to build a custom Docker container with latest Transformers, TRL, etc. with CUDA support on Google Cloud, and how to run it via Vertex AI Serverless Training Jobs.
|
||||
|
||||
```python
|
||||
# pip install google-cloud-aiplatform --upgrade --quiet
|
||||
from google.cloud import aiplatform
|
||||
|
||||
aiplatform.init(
|
||||
project="<PROJECT_ID>",
|
||||
location="<LOCATION>",
|
||||
staging_bucket="<BUCKET_URI>",
|
||||
)
|
||||
|
||||
job = aiplatform.CustomContainerTrainingJob(
|
||||
display_name="gemma-4-fine-tuning",
|
||||
container_uri="<CONTAINER_URI>",
|
||||
command=["python", "/gcs/gemma-4-fine-tuning/train.py"],
|
||||
)
|
||||
|
||||
job = job.submit(
|
||||
replica_count=1,
|
||||
machine_type="a3-highgpu-1g",
|
||||
accelerator_type="NVIDIA_H100_80GB",
|
||||
accelerator_count=1,
|
||||
base_output_dir="<BUCKET_URI>/output-dir",
|
||||
environment_variables={
|
||||
"MODEL_ID": "google/gemma-4-E2B-it",
|
||||
"HF_TOKEN": <HF_TOKEN>,
|
||||
},
|
||||
boot_disk_size_gb=500,
|
||||
)
|
||||
```
|
||||
|
||||
You can find the complete example in the "Hugging Face on Google Cloud" docs at https://hf.co/docs/google-cloud/examples/vertex-ai-notebooks-fine-tune-gemma-4.
|
||||
|
||||
## Fine-tuning with Unsloth Studio
|
||||
|
||||
If you want to fine tune and run a Gemma 4 model in a UI, try out [Unsloth Studio](https://unsloth.ai/docs/new/studio). It runs locally or on Google Colab. First, install and start the app:
|
||||
|
||||
```shell
|
||||
# install unsloth studio on MacOS, Linux, WSL
|
||||
curl -fsSL https://unsloth.ai/install.sh | sh
|
||||
|
||||
# install unsloth studio on Windows
|
||||
irm https://unsloth.ai/install.ps1 | iex
|
||||
|
||||
# launch unsloth studio
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
# Search for for a Gemma 4 model like google/gemma-4-E2B-it
|
||||
```
|
||||
|
||||
Then select any of the Gemma 4 models from the hub.
|
||||
|
||||

|
||||
|
||||
## Try Gemma 4
|
||||
|
||||
We have shipped demos for you to try different Gemma 4 models. We include demos based on transformers implementation for [E4B](https://huggingface.co/spaces/huggingface-projects/gemma-4-e4b-it), [26B/A4B](https://huggingface.co/spaces/huggingface-projects/gemma-4-26b-a4b-it), and dense [31B](https://huggingface.co/spaces/huggingface-projects/gemma-4-31b-it) models, as well as a [WebGPU](https://huggingface.co/spaces/webml-community/Gemma-4-WebGPU) demo with transformers.js 🚀
|
||||
|
||||
|
||||
<iframe width="560" height="315" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/webgpu_demo.mp4" title="WebGPU Demo" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen> </iframe>
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
Gemma 4 models demonstrate exceptional performance across diverse benchmarks, from reasoning and coding to vision and long-context tasks. The graph below shows model performance vs size, with Gemma 4 models forming an impressive Pareto frontier:
|
||||
|
||||
<div style="display: flex; gap: 20px; justify-content: center; align-items: flex-start; flex-wrap: wrap;">
|
||||
<figure style="flex: 1; min-width: 300px; text-align: center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/g4_graph.png" alt="Gemma 4 Performance vs Size" style="width: 100%; height: 400px; object-fit: contain;">
|
||||
</figure>
|
||||
<figure style="flex: 1; min-width: 300px; text-align: center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/g4_graph_2.png" alt="Gemma 4 Arena Elo Score Comparison" style="width: 100%; height: 400px; object-fit: contain;">
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
<p style="text-align: center; font-size: 0.9em; color: #666;">Source: Google (<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/">blog.google</a>)</p>
|
||||
|
||||
Here are detailed benchmark results for the instruction-tuned models:
|
||||
|
||||
| Benchmark | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
|
||||
|-----------|-------------|-----------------|-------------|-------------|------------------------|
|
||||
| **Reasoning & Knowledge** |
|
||||
| MMLU Pro | [85.2%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-31B-it) | [82.6%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-26B-A4B-it) | [69.4%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-E4B-it) | [60.0%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-E2B-it) | 67.6% |
|
||||
| AIME 2026 no tools | [89.2%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-31B-it) | [88.3%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-26B-A4B-it) | [42.5%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-E4B-it) | [37.5%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-E2B-it) | 20.8% |
|
||||
| GPQA Diamond | [84.3%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-31B-it) | [82.3%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-26B-A4B-it) | [58.6%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-E4B-it) | [43.4%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-E2B-it) | 42.4% |
|
||||
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
|
||||
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
|
||||
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
|
||||
| **Coding** |
|
||||
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
|
||||
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
|
||||
| HLE no tools | [19.5%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-31B-it) | [8.7%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-26B-A4B-it) | - | - | - |
|
||||
| HLE with search | [26.5%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-31B-it) | [17.2%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-26B-A4B-it) | - | - | - |
|
||||
| **Vision** |
|
||||
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
|
||||
| OmniDocBench 1.5 (edit distance) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
|
||||
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
|
||||
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
|
||||
| **Audio** |
|
||||
| CoVoST | - | - | 35.54 | 33.47 | - |
|
||||
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
|
||||
| **Long Context** |
|
||||
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
Landing Gemma-4 in the open-source ecosystem took a lot of effort from many people and not only the authors of this blog post. In no particular order, we thank many people from the open-source team: Gemma 4 transformers integration is owed to Cyril, Raushan, Eustache, Arthur, Lysandre. We thank Joshua for transformers.js integration and demo, Eric for mistral.rs integration, Son for Llama.cpp, Prince for MLX integration, Quentin, Albert and Kashif for TRL, Adarsh for SGLang transformers backend and Toshihiro for building the demos.
|
||||
This work wouldn't have been possible without Google's extensive contribution with the model artefact, but also the significant effort contributing the model to transformers in an effort to standardize it. The open-source ecosystem is now more complete, with a very capable, freely-licensed, open-source model.
|
||||
The Gemma 4 transformers integration was handled by Cyril, Raushan, Eustache, Arthur, Lysandre. We thank Joshua for the transformers.js integration and demo, Eric for mistral.rs integration, Son for Llama.cpp, Prince for MLX, Quentin for TRL, Adarsh for SGLang transformers backend, and Toshihiro for building several demos.
|
||||
|
||||
This work wouldn't have been possible without Google's extensive contribution with the model artefact, but also their significant effort contributing the model to transformers in an effort to standardize it. The open-source ecosystem is now more complete, with a very capable, freely-licensed, open-source model.
|
||||
@@ -0,0 +1,53 @@
|
||||
# Ollama: Importing a LoRA/QLoRA Adapter (Gemma 4 applicable)
|
||||
|
||||
Source: https://docs.ollama.com/import (fetched 2026-04-18)
|
||||
|
||||
## Modelfile syntax
|
||||
|
||||
**Safetensors adapter (merged or unmerged):**
|
||||
```dockerfile
|
||||
FROM <base model name>
|
||||
ADAPTER /path/to/safetensors/adapter/directory
|
||||
```
|
||||
|
||||
**GGUF adapter:**
|
||||
```dockerfile
|
||||
FROM <base model name>
|
||||
ADAPTER /path/to/file.gguf
|
||||
```
|
||||
|
||||
## Creation
|
||||
```shell
|
||||
ollama create my-model
|
||||
```
|
||||
|
||||
## Critical notes
|
||||
|
||||
- **The `FROM` base model MUST match the base the adapter was trained on** or you'll get erratic results. For Gemma 4: `FROM gemma4:e4b-it-q8_0` (or whichever base was used).
|
||||
- **Non-QLoRA adapters preferred.** Quantized adapter recipes (QLoRA) sometimes diverge in method across frameworks; using a straight LoRA adapter is more portable.
|
||||
- Gemma 4 is NOT explicitly listed in the Ollama docs' "supported architectures" section (which lists Llama 2/3, Mistral, Gemma 1/2) — but llama.cpp gained Gemma 4 support day one, and the Ollama gemma4 models work. Expect smooth sailing for text; vision adapters are a grey area.
|
||||
|
||||
## Converting a PEFT / Unsloth adapter to GGUF
|
||||
|
||||
Use llama.cpp's `convert_lora_to_gguf.py`:
|
||||
```bash
|
||||
python llama.cpp/convert_lora_to_gguf.py \
|
||||
--outfile gemma4-mortdecai-adapter.gguf \
|
||||
path/to/peft/adapter_dir
|
||||
```
|
||||
Or use HuggingFace's "GGUF-my-LoRA" Space: https://huggingface.co/spaces/ggml-org/gguf-my-lora (web UI).
|
||||
|
||||
## Unsloth fast path
|
||||
|
||||
Unsloth's notebooks finish with a cell that does exactly:
|
||||
```python
|
||||
model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
|
||||
```
|
||||
which produces a GGUF suitable for direct `ollama create`.
|
||||
|
||||
## Workflow for Seth's homelab
|
||||
|
||||
1. Fine-tune with Unsloth on a rented H100/H200 (or local 3090 for E4B).
|
||||
2. `model.save_pretrained_merged("merged_out", tokenizer, save_method = "merged_16bit")` — save the merged model in 16-bit safetensors.
|
||||
3. Use llama.cpp's `convert_hf_to_gguf.py` to make a GGUF, then quantize to Q4_K_M.
|
||||
4. Write a Modelfile pointing at the GGUF, `ollama create mortdecai-gemma4:v1 -f Modelfile`, push to local Ollama (pve197 CT 105 or steel141).
|
||||
@@ -0,0 +1,190 @@
|
||||
# Recommended Gemma 4 Fine-Tuning Recipe (Seth's Homelab)
|
||||
|
||||
## TL;DR
|
||||
|
||||
**Use Unsloth. Rent a single H100 on Vast.ai. Fine-tune Gemma 4 E4B (or 31B QLoRA). Save GGUF. `ollama create` back to CT 105.**
|
||||
|
||||
Why not the alternatives:
|
||||
- **Your 3090 Ti(s):** can handle E2B/E4B LoRA comfortably, but 26B A4B LoRA wants ~40 GB and 31B QLoRA wants 22 GB (fits, tightly). Axolotl's 5090-validated configs need Flex Attention to fit, and you lose half the throughput. An H100 at $2-3/hr for 3-4 hours is cheaper than the time you'll spend tuning memory.
|
||||
- **Axolotl** is great — in particular the 26B MoE ScatterMoE+expert-LoRA config is genuinely novel and Unsloth doesn't match it. But Axolotl has more moving parts (FSDP, kernels, flex attention), breaks more subtly on config errors, and the docs are less Gemma-4-specific than Unsloth's.
|
||||
- **TRL** has no Gemma-4-specific SFT script yet — you'd be porting `sft_gemma3.py`. Useful if you need DPO/GRPO or multimodal tool-call GRPO (the CARLA recipe), but heavier lift than Unsloth for plain SFT.
|
||||
- **Google cookbook** works and is authoritative but is slower than Unsloth (no fused kernels) and the notebook format is noisier to modify.
|
||||
|
||||
## Exact command
|
||||
|
||||
### On a rented H100 (Vast.ai `vast-h100` alias, already configured)
|
||||
|
||||
```bash
|
||||
ssh vast-h100
|
||||
# one-time setup
|
||||
pip install unsloth "trl==0.22.2" "transformers>=5.5.0" timm torchcodec
|
||||
```
|
||||
|
||||
Training script (save as `finetune_gemma4.py` on the H100):
|
||||
|
||||
```python
|
||||
from unsloth import FastModel
|
||||
from unsloth.chat_templates import get_chat_template, standardize_data_formats, train_on_responses_only
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
MODEL = "unsloth/gemma-4-E4B-it" # swap to "unsloth/gemma-4-31B-it" if you want more headroom
|
||||
DATASET = "YOUR_DATASET_HERE" # e.g. a mortdecai-style chat JSONL on HF Hub
|
||||
|
||||
# 1. Load model + tokenizer in 4-bit
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = MODEL,
|
||||
max_seq_length = 4096,
|
||||
load_in_4bit = True,
|
||||
full_finetuning = False,
|
||||
)
|
||||
|
||||
# 2. Attach LoRA
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # text-only FT
|
||||
finetune_language_layers = True,
|
||||
finetune_attention_modules = True,
|
||||
finetune_mlp_modules = True,
|
||||
r = 16,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
# 3. Chat template — "gemma-4" (literal, with dash)
|
||||
tokenizer = get_chat_template(tokenizer, chat_template = "gemma-4")
|
||||
|
||||
# 4. Dataset: expects ShareGPT-style `conversations` field with {from, value}
|
||||
# OR OpenAI-style `messages` with {role, content} — standardize_data_formats handles both.
|
||||
dataset = load_dataset(DATASET, split = "train")
|
||||
dataset = standardize_data_formats(dataset)
|
||||
|
||||
def fmt(examples):
|
||||
convos = examples["conversations"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=False)
|
||||
.removeprefix('<bos>') # critical: avoid double <bos>
|
||||
for c in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
dataset = dataset.map(fmt, batched=True)
|
||||
|
||||
# 5. Train
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset,
|
||||
args = SFTConfig(
|
||||
dataset_text_field = "text",
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_steps = 10,
|
||||
num_train_epochs = 1,
|
||||
learning_rate = 2e-4,
|
||||
logging_steps = 1,
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
report_to = "none",
|
||||
output_dir = "outputs",
|
||||
),
|
||||
)
|
||||
|
||||
# 6. Mask everything except assistant turns
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|turn>user\n",
|
||||
response_part = "<|turn>model\n",
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# 7. Save merged 16-bit for GGUF conversion
|
||||
model.save_pretrained_merged("merged_out", tokenizer, save_method = "merged_16bit")
|
||||
|
||||
# 8. OR save directly to GGUF (Q4_K_M) — Ollama-ready
|
||||
model.save_pretrained_gguf("gemma4-mortdecai-v1", tokenizer, quantization_method = "q4_k_m")
|
||||
```
|
||||
|
||||
Run:
|
||||
```bash
|
||||
python finetune_gemma4.py
|
||||
```
|
||||
|
||||
### Pulling the result back and serving on CT 105
|
||||
|
||||
```bash
|
||||
# On the Vast box, upload to HF Hub or scp back:
|
||||
scp -r vast-h100:~/gemma4-mortdecai-v1*.gguf steel141:/tmp/
|
||||
|
||||
# On CT 105 (pve197 Ollama):
|
||||
cat > Modelfile <<'EOF'
|
||||
FROM /path/to/gemma4-mortdecai-v1.Q4_K_M.gguf
|
||||
PARAMETER num_ctx 8192
|
||||
PARAMETER temperature 1.0
|
||||
PARAMETER top_p 0.95
|
||||
PARAMETER top_k 64
|
||||
SYSTEM "You are Mortdecai, a Minecraft ops AI. You are powered by Gemma 4."
|
||||
EOF
|
||||
ollama create mortdecai-gemma4:v1 -f Modelfile
|
||||
ollama run mortdecai-gemma4:v1
|
||||
```
|
||||
|
||||
## Hardware sizing guide (from Unsloth's verified numbers)
|
||||
|
||||
| Variant | LoRA | QLoRA | Full FT | My recommendation |
|
||||
|---------|------|-------|---------|-------------------|
|
||||
| E2B | 8-10 GB | 8 GB | ~20 GB | Free Colab T4; local 3090 Ti fine |
|
||||
| E4B | 17 GB | 10 GB | ~32 GB | Local 3090 Ti (24 GB) tight but fine; H100 faster |
|
||||
| 26B A4B | >40 GB (16-bit recommended, NOT 4-bit) | not recommended | — | H100 80 GB |
|
||||
| 31B dense | >48 GB | 22 GB | 2×H100 | H100 80 GB or 2×3090 Ti FSDP |
|
||||
|
||||
For **Mortdecai-style behavior tuning** (matches your existing qwen-based setup), start with **E4B**. It's the sweet spot: larger than qwen3 8B in the things that matter (Gemma 4 E4B beats Gemma 3 27B on most benchmarks), vision-capable if you want it, and fits on a single 3090 Ti locally.
|
||||
|
||||
For a **real coding/reasoning upgrade**, use **31B QLoRA on H100**. Unsloth's 31B QLoRA notebook is the canonical recipe there.
|
||||
|
||||
## Gemma-4-specific pitfalls to NOT miss
|
||||
|
||||
1. **New chat template.** Gemma 4 uses `<|turn>user\n … <turn|>` — NOT Gemma 3's `<start_of_turn>user\n … <end_of_turn>`. Unsloth's `get_chat_template(tokenizer, chat_template="gemma-4")` handles this; the HF tokenizer's built-in Jinja also handles it if you rely on `apply_chat_template`. Axolotl uses `chat_template: gemma4` (no dash — different key).
|
||||
|
||||
2. **6 new tool-calling tokens.** `<|tool>`, `<tool|>`, `<|tool_call>`, `<tool_call|>`, `<|tool_response>`, `<tool_response|>`, plus the string-delimiter `<|"|>`. If fine-tuning on tool-call data, include full `<|tool_call>call:fn_name{args}<tool_call|>` in the assistant turn — no `role="tool"` branch exists.
|
||||
|
||||
3. **`modules_to_save=["lm_head","embed_tokens"]` + `ensure_weight_tying=True`** in LoraConfig if going vanilla PEFT (Google's cookbook does this explicitly). The new special tokens are *learned embeddings* — if the embed table is frozen, the adapter sees random vectors for them and training silently underperforms. Unsloth and Axolotl bake this in.
|
||||
|
||||
4. **Freeze the vision/audio tower by default.** Two idioms in the wild:
|
||||
- Axolotl: `freeze_mm_modules: true` + text-only LoRA regex.
|
||||
- HF's CARLA example: `target_modules="all-linear"` + `exclude_modules=["vision_tower", "multi_modal_projector"]`.
|
||||
Only train the vision tower if your task specifically needs the encoder to adapt (new image domain). For text-mode fine-tunes like Mortdecai, always freeze.
|
||||
|
||||
5. **Flash Attention DOES NOT WORK on Gemma 4.** FA2's max `head_dim=256`, FA4's is 128; Gemma 4's `global_head_dim=512` exceeds both. **Use SDP or Flex Attention.** Axolotl's configs set `sdp_attention: true`. TRL's `sft_gemma3.py` uses `attn_implementation="eager"` — this works but is slow; prefer `"sdpa"`. (Unsloth's FastModel handles this automatically.)
|
||||
|
||||
6. **LoRA kernels OFF.** Gemma 4's shared-KV-cache layers break the fused LoRA kernels. Axolotl sets `lora_mlp_kernel/qkv_kernel/o_kernel: false` explicitly. Unsloth's `FastModel` is fine because it uses its own kernel path that knows about shared-KV.
|
||||
|
||||
7. **Don't prepend a second `<bos>`.** `apply_chat_template` adds one; SFTTrainer's collator adds one; if you don't `.removeprefix('<bos>')` before passing text to the trainer, you train the model to expect `<bos><bos>`. Unsloth's example notebooks do this strip — copy their pattern.
|
||||
|
||||
8. **26B A4B: use 16-bit LoRA, not QLoRA.** Unsloth's docs explicitly say "MoE QLoRA not recommended, dense 31B is fine." Axolotl has a ScatterMoE+expert-quantized+expert-LoRA config that does make 4-bit work for the MoE (validated on a 5090), but it's the only tool that does — Unsloth's 26B A4B notebook goes 16-bit for quality.
|
||||
|
||||
9. **Initial training loss of 13-15 on E2B/E4B is normal, not a bug.** Multimodal models start much higher than 5-8. If you see 13-15 don't panic — GOTCHAS.md §"Fine-Tuning Ecosystem Issues" has this.
|
||||
|
||||
10. **`mm_token_type_ids` required during training even for text-only data.** Day-one PEFT/Transformers bug: the multimodal collator requires this field. Pin `transformers>=5.5.0` and `peft>=0.15` to ensure the fix is present.
|
||||
|
||||
## Feature parity snapshot (2026-04-18)
|
||||
|
||||
| Feature | Unsloth | TRL | Axolotl | Google cookbook |
|
||||
|---------|:-:|:-:|:-:|:-:|
|
||||
| Text SFT | ✓ | ~ (via gemma3 script, change model_id) | ✓ | ✓ |
|
||||
| Vision SFT | ✓ | ~ (via sft_vlm_gemma3) | ✓ (E2B) | ✓ |
|
||||
| Audio SFT | ✓ (E2B/E4B) | ✗ | ✗ | ✗ |
|
||||
| GRPO | ✓ (E2B + RL game notebooks) | ✓ (CARLA VLM-GRPO, official) | ✗ | ✗ |
|
||||
| DPO | via TRL | ✓ | ✓ | ✗ |
|
||||
| 26B MoE native | ✓ (16-bit LoRA) | ~ | ✓ (ScatterMoE + expert-LoRA, validated on 5090) | ✗ |
|
||||
| 31B dense QLoRA | ✓ | ~ | ✓ (with Flex Attn) | ~ |
|
||||
| Free Colab T4 path | ✓ (E2B) | ✗ | ✗ | ~ (via Colab Pro) |
|
||||
| Multi-GPU FSDP | ~ | ✓ | ✓ (first-class) | ~ |
|
||||
|
||||
**Bottom line:** Unsloth has the broadest Gemma-4-native coverage (including audio and RL games, which no one else has). Axolotl has the best 26B MoE story. TRL has the best multimodal-RL story (CARLA). Google cookbook is the reference, not the fast path.
|
||||
|
||||
For Seth's stated use case (fine-tune like mortdecai), Unsloth wins on ergonomics + speed + T4 free-tier fallback.
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###############################################################################################
|
||||
# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py #
|
||||
###############################################################################################
|
||||
@@ -0,0 +1,320 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
# Full training
|
||||
```
|
||||
python examples/scripts/grpo_agent.py \
|
||||
--model_name_or_path Qwen/Qwen3-1.7B \
|
||||
--output_dir grpo_biogrid_qwen_3g-1.7b \
|
||||
--push_to_hub True \
|
||||
--use_vllm True \
|
||||
--vllm_mode colocate \
|
||||
--max_completion_length 1024 \
|
||||
--report_to trackio \
|
||||
--log_completions True \
|
||||
--max_steps 400
|
||||
```
|
||||
"""
|
||||
|
||||
import re
|
||||
import signal
|
||||
import sqlite3
|
||||
import textwrap
|
||||
from contextlib import contextmanager
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser
|
||||
|
||||
|
||||
def query_reward(completions, answer, **kwargs):
|
||||
"""
|
||||
Reward query strategy:
|
||||
- Penalize more than 2 queries
|
||||
- Penalize generic queries (LIMIT 1 / PRAGMA)
|
||||
- Reward usage of WHERE
|
||||
- Reward evidence supporting the final answer
|
||||
"""
|
||||
rewards = []
|
||||
|
||||
for completion, ans in zip(completions, answer, strict=False):
|
||||
reward = 0.0
|
||||
sql_queries = []
|
||||
tool_results = []
|
||||
|
||||
# collect all SQL queries and tool results
|
||||
for turn in completion:
|
||||
if turn.get("tool_calls"):
|
||||
for call in turn["tool_calls"]:
|
||||
sql = call["function"]["arguments"].get("sql_command", "").lower()
|
||||
sql_queries.append(sql)
|
||||
if turn.get("role") == "tool" and turn.get("content"):
|
||||
tool_results.append(turn["content"])
|
||||
|
||||
# --- penalize too many queries ---
|
||||
if len(sql_queries) > 3:
|
||||
reward -= 1.5
|
||||
|
||||
# --- check query quality ---
|
||||
where_count = 0
|
||||
for q in sql_queries:
|
||||
if "limit 1" in q:
|
||||
reward -= 1.0
|
||||
if " where " not in q:
|
||||
reward -= 0.5
|
||||
else:
|
||||
where_count += 1
|
||||
reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage
|
||||
|
||||
# --- evidence check: do queries support the answer? ---
|
||||
combined_results = []
|
||||
error_detected = False
|
||||
|
||||
for res in tool_results:
|
||||
if isinstance(res, dict) and "error" in res:
|
||||
error_detected = True
|
||||
elif isinstance(res, list):
|
||||
combined_results.extend(res)
|
||||
|
||||
# if error detected, penalize heavily
|
||||
if error_detected:
|
||||
reward -= 2.0
|
||||
elif len(sql_queries) == 0:
|
||||
reward -= 1.5
|
||||
else:
|
||||
has_hits = len(combined_results) > 0
|
||||
correct_answer = ans.lower()
|
||||
if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"):
|
||||
reward += 2.0
|
||||
else:
|
||||
reward -= 1.5
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
def correctness_reward(completions, answer, **kwargs):
|
||||
"""
|
||||
Reward Yes/No correctness.
|
||||
Model must provide final answer enclosed in stars — *yes* or *no*.
|
||||
Does not reward informal yes/no buried in text.
|
||||
"""
|
||||
rewards = []
|
||||
for completion, ans in zip(completions, answer, strict=False):
|
||||
raw = completion[-1]["content"].lower()
|
||||
|
||||
# detect form *yes* or *no*
|
||||
match = re.search(r"\*(yes|no)\*", raw)
|
||||
guess = match.group(1) if match else None
|
||||
|
||||
reward = 0.0
|
||||
|
||||
if guess is None:
|
||||
reward -= 0.5 # invalid format
|
||||
elif guess == ans.lower():
|
||||
reward += 0.6 # correct under required format
|
||||
else:
|
||||
reward -= 1.0 # wrong answer
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
def structure_reward(completions, **kwargs):
|
||||
"""
|
||||
Reward proper assistant structure.
|
||||
Encourages a logical sequence: tool call + response + optional extra content.
|
||||
"""
|
||||
rewards = []
|
||||
|
||||
for completion in completions:
|
||||
has_call = False
|
||||
has_response = False
|
||||
has_other = False
|
||||
|
||||
for turn in completion:
|
||||
role = turn.get("role")
|
||||
if role == "assistant" and turn.get("tool_calls"):
|
||||
has_call = True
|
||||
elif role == "tool":
|
||||
has_response = True
|
||||
else:
|
||||
content = turn.get("content")
|
||||
if content and content.strip() not in ["", "<think>"]:
|
||||
has_other = True
|
||||
|
||||
# Reward sequences
|
||||
if has_call and has_response:
|
||||
if has_other:
|
||||
reward = 0.1
|
||||
else:
|
||||
reward = 0.05 # still positive even without extra text
|
||||
elif has_call and not has_response:
|
||||
reward = -0.15
|
||||
else:
|
||||
reward = 0.0 # neutral if no call
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Database tool function
|
||||
# ------------------------
|
||||
class TimeoutError(Exception):
|
||||
"""Raised when a function call times out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timeout(seconds):
|
||||
"""Context manager that raises TimeoutError if execution exceeds time limit."""
|
||||
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutError(f"Operation timed out after {seconds} seconds")
|
||||
|
||||
signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(seconds)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
def query_biogrid(sql_command: str) -> list[tuple]:
|
||||
"""
|
||||
Execute a read-only SQL command on the BioGRID database.
|
||||
|
||||
BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.
|
||||
|
||||
Args:
|
||||
sql_command: The SQL command to execute.
|
||||
|
||||
Returns:
|
||||
A list of tuples containing the query results.
|
||||
"""
|
||||
with timeout(5):
|
||||
conn = sqlite3.connect("file:biogrid.db?mode=ro", uri=True)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql_command)
|
||||
results = cursor.fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
return results
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Dataset formatting
|
||||
# ------------------------
|
||||
def format_example(example):
|
||||
question = example["question"]
|
||||
preamble = textwrap.dedent("""\
|
||||
You have access to the BioGRID SQLite database.
|
||||
Use SQL queries to retrieve only the information needed to answer the question.
|
||||
|
||||
Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,
|
||||
and each entry can contain multiple gene names or synonyms separated by '|', for example:
|
||||
'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'
|
||||
So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.
|
||||
|
||||
If the database schema is unclear or you are unsure about column names:
|
||||
- First inspect the schema with `PRAGMA table_info(interactions);`
|
||||
- Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`
|
||||
|
||||
Otherwise, directly query the required data.
|
||||
|
||||
Final answer must be enclosed in stars, e.g. *Yes* or *No*.
|
||||
Facts:
|
||||
- The NCBI Taxonomy identifier for humans is taxid:9606.
|
||||
""")
|
||||
content = f"{preamble}\nQuestion: {question}"
|
||||
prompt = [{"role": "user", "content": content}]
|
||||
return {"prompt": prompt}
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Main
|
||||
# ------------------------
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
|
||||
# ------------------------
|
||||
# Create DB
|
||||
# ------------------------
|
||||
print("Creating biogrid.db...")
|
||||
# Load dataset
|
||||
biogrid_dataset = load_dataset("qgallouedec/biogrid", split="train")
|
||||
df = biogrid_dataset.to_pandas()
|
||||
|
||||
# Normalize column names: remove spaces, replace with underscores
|
||||
df.columns = [c.replace(" ", "_") for c in df.columns]
|
||||
conn = sqlite3.connect("biogrid.db")
|
||||
try:
|
||||
df.to_sql("interactions", conn, if_exists="replace", index=False)
|
||||
print(f"biogrid.db created. Rows stored: {len(df)}")
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------
|
||||
# Load and format dataset
|
||||
# ------------------------
|
||||
dataset = load_dataset("qgallouedec/biogrid_qa", split="train")
|
||||
dataset = dataset.filter(
|
||||
lambda example: example["question"].startswith("Does the gene ")
|
||||
) # keep only simple questions for example
|
||||
dataset = dataset.map(format_example, remove_columns=["question"])
|
||||
|
||||
train_dataset = dataset
|
||||
eval_dataset = None # No eval by default, can be added if needed
|
||||
|
||||
training_args.chat_template_kwargs = {"enable_thinking": False}
|
||||
|
||||
# ------------------------
|
||||
# Initialize trainer
|
||||
# ------------------------
|
||||
trainer = GRPOTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tools=[query_biogrid],
|
||||
reward_funcs=[correctness_reward, structure_reward, query_reward],
|
||||
args=training_args,
|
||||
)
|
||||
|
||||
# ------------------------
|
||||
# Train
|
||||
# ------------------------
|
||||
trainer.train()
|
||||
|
||||
# ------------------------
|
||||
# Save and push
|
||||
# ------------------------
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "Pillow",
|
||||
# "math-verify",
|
||||
# "latex2sympy2_extended",
|
||||
# "torchvision",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
pip install math_verify
|
||||
|
||||
# For Qwen/Qwen2.5-VL-3B-Instruct
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/grpo_vlm.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
|
||||
--learning_rate 1e-5 \
|
||||
--dtype bfloat16 \
|
||||
--max_completion_length 1024 \
|
||||
--use_vllm \
|
||||
--vllm_mode colocate \
|
||||
--use_peft \
|
||||
--lora_target_modules "q_proj", "v_proj" \
|
||||
--log_completions
|
||||
|
||||
# For HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
pip install num2words==0.5.14
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/grpo_vlm.py \
|
||||
--model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
|
||||
--output_dir grpo-SmolVLM2-2.2B-Instruct \
|
||||
--learning_rate 1e-5 \
|
||||
--dtype bfloat16 \
|
||||
--max_completion_length 1024 \
|
||||
--use_peft \
|
||||
--lora_target_modules "q_proj", "v_proj" \
|
||||
--log_completions \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_generations 2
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
GRPOTrainer,
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
################
|
||||
# Model
|
||||
################
|
||||
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
||||
training_args.model_init_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
dtype=dtype,
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
if quantization_config is not None:
|
||||
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
||||
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
|
||||
training_args.model_init_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
|
||||
dataset = dataset.train_test_split(test_size=100, seed=42)
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
|
||||
"assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
|
||||
"The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
|
||||
"reasoning.\n</think>\nThis is my answer."
|
||||
)
|
||||
|
||||
def make_conversation(example):
|
||||
prompt = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": example["problem"]},
|
||||
]
|
||||
return {"prompt": prompt}
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
|
||||
# Filter have big images
|
||||
def filter_big_images(example):
|
||||
image = example["image"]
|
||||
return image.size[0] < 512 and image.size[1] < 512
|
||||
|
||||
dataset = dataset.filter(filter_big_images)
|
||||
|
||||
def convert_to_rgb(example):
|
||||
image = example["image"]
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
example["image"] = image
|
||||
return example
|
||||
|
||||
dataset = dataset.map(convert_to_rgb)
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = GRPOTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
args=training_args,
|
||||
reward_funcs=[think_format_reward, accuracy_reward],
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###############################################################################################
|
||||
# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py #
|
||||
###############################################################################################
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "Pillow",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Train Gemma-3 on the Codeforces COTS dataset.
|
||||
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py
|
||||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
|
||||
def main():
|
||||
# Load dataset
|
||||
train_dataset = load_dataset("open-r1/codeforces-cots", split="train")
|
||||
train_dataset = train_dataset.remove_columns("prompt")
|
||||
|
||||
# Load model
|
||||
model_id = "google/gemma-3-12b-it"
|
||||
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")
|
||||
|
||||
# Train model
|
||||
training_args = SFTConfig(
|
||||
output_dir=f"{model_id}-codeforces-SFT",
|
||||
bf16=True,
|
||||
use_liger_kernel=True,
|
||||
max_length=8192,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=8,
|
||||
dataset_num_proc=32,
|
||||
num_train_epochs=1,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Push to hub
|
||||
trainer.push_to_hub(dataset_name="open-r1/codeforces-cots")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "bitsandbytes",
|
||||
# "liger-kernel",
|
||||
# "trackio",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Teach tool calling to CohereLabs/tiny-aya-global using SFT with QLoRA on the bebechien/SimpleToolCalling dataset.
|
||||
|
||||
The model used in this script does not have native tool-calling support. We extend its existing Jinja2 chat template to
|
||||
serialize tool schemas into the system preamble and render tool calls as structured <tool_call> XML inside the model's
|
||||
native <|START_RESPONSE|> / <|END_RESPONSE|> delimiters. The modified template is saved with the tokenizer, so
|
||||
inference only requires loading the tokenizer from the output directory and calling apply_chat_template with
|
||||
tools=TOOLS — no manual system-prompt construction needed.
|
||||
|
||||
Example:
|
||||
|
||||
python examples/scripts/sft_tiny_aya_tool_calling.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
|
||||
# These are the tool schemas that are used in the dataset
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_knowledge_base",
|
||||
"description": "Search internal company documents, policies and project data.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string", "description": "query string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
"return": {"type": "string"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_google",
|
||||
"description": "Search public information.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string", "description": "query string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
"return": {"type": "string"},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def create_conversation(sample):
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": sample["user_content"]}],
|
||||
"completion": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": sample["tool_name"],
|
||||
"arguments": json.loads(sample["tool_arguments"]),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
"tools": TOOLS,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
model_id = "CohereLabs/tiny-aya-global"
|
||||
dataset_name = "bebechien/SimpleToolCalling"
|
||||
output_dir = "tiny-aya-global-tool-calling-SFT"
|
||||
|
||||
# Load and format dataset
|
||||
dataset = load_dataset(dataset_name, split="train")
|
||||
dataset = dataset.map(create_conversation, remove_columns=dataset.features)
|
||||
dataset = dataset.train_test_split(test_size=0.5, shuffle=True)
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="sdpa",
|
||||
dtype=torch.float16,
|
||||
quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
),
|
||||
)
|
||||
|
||||
# Configure LoRA
|
||||
peft_config = LoraConfig(
|
||||
r=32,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
)
|
||||
|
||||
# Train
|
||||
training_args = SFTConfig(
|
||||
output_dir=output_dir,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
# Use the tool-aware chat template
|
||||
chat_template_path=str(Path(__file__).parent / "tiny_aya_chat_template.jinja"),
|
||||
warmup_steps=5,
|
||||
learning_rate=2e-4,
|
||||
optim="paged_adamw_8bit",
|
||||
logging_steps=1,
|
||||
report_to="trackio",
|
||||
trackio_space_id=output_dir,
|
||||
max_length=1024,
|
||||
use_liger_kernel=True,
|
||||
activation_offloading=True,
|
||||
push_to_hub=True,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset["train"],
|
||||
peft_config=peft_config,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Save model and tokenizer (tokenizer carries the updated chat template)
|
||||
trainer.save_model(output_dir)
|
||||
trainer.push_to_hub(dataset_name=dataset_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "Pillow>=9.4.0",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
pip install pillow
|
||||
|
||||
# Tested on 8x H100 GPUs
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path llava-hf/llava-1.5-7b-hf \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--output_dir LLaVA-1.5-7B-SFT \
|
||||
--dtype bfloat16
|
||||
|
||||
For LLaVA-NeXT, use:
|
||||
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
|
||||
|
||||
For meta-llama/Llama-3.2-11B-Vision-Instruct, use:
|
||||
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path HuggingFaceTB/SmolVLM-Instruct \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir SmolVLM-SFT \
|
||||
--dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
training_args.max_length = None
|
||||
|
||||
################
|
||||
# Model
|
||||
################
|
||||
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
dtype=dtype,
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
if quantization_config is not None:
|
||||
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
||||
model_kwargs["device_map"] = get_kbit_device_map()
|
||||
model_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
||||
)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
@@ -0,0 +1,189 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "Pillow>=9.4.0",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm_gemma3.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--output_dir Gemma-3-4B-SFT-MMIU \
|
||||
--dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm_gemma3.py \
|
||||
--dataset_name FanqingM/MMIU-Benchmark \
|
||||
--dataset_train_split test \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--output_dir Gemma-3-4B-SFT-MMIU \
|
||||
--dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
for element in content:
|
||||
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
if image is not None:
|
||||
image = Image.open(io.BytesIO(image["bytes"]))
|
||||
image_inputs.append(image.convert("RGB"))
|
||||
return image_inputs
|
||||
|
||||
|
||||
def format_data(samples: dict[str, any]) -> dict[str, list]:
|
||||
formatted_samples = {"messages": []}
|
||||
for cont in range(len(samples["question"])):
|
||||
images = []
|
||||
for img_path in samples["input_image_path"][cont]:
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||
images.append({"type": "image", "image": image})
|
||||
except Exception as e:
|
||||
print(f"Error processing image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
formatted_samples["messages"].append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
|
||||
{"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
|
||||
]
|
||||
)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict:
|
||||
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
||||
zip_files = [f for f in all_files if f.endswith(".zip")]
|
||||
|
||||
for zip_filename in zip_files:
|
||||
zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
|
||||
extract_folder = zip_filename.replace(".zip", "")
|
||||
os.makedirs(extract_folder, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_folder)
|
||||
|
||||
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
training_args.max_length = None
|
||||
|
||||
################
|
||||
# Model
|
||||
################
|
||||
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
dtype=dtype,
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
if quantization_config is not None:
|
||||
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
||||
model_kwargs["device_map"] = get_kbit_device_map()
|
||||
model_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
||||
)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
|
||||
dataset = prepare_dataset(dataset, script_args.dataset_name)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,156 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "peft",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
# Full training
|
||||
```
|
||||
python trl/scripts/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--eos_token '<|im_end|>' \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--output_dir Qwen2-0.5B-SFT \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
# LoRA
|
||||
```
|
||||
python trl/scripts/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--learning_rate 2.0e-4 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--eos_token '<|im_end|>' \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--use_peft \
|
||||
--lora_r 32 \
|
||||
--lora_alpha 16 \
|
||||
--output_dir Qwen2-0.5B-SFT \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args, dataset_args):
|
||||
from accelerate import logging
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
|
||||
|
||||
from trl import SFTTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
################
|
||||
# Model init kwargs
|
||||
################
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
dtype=model_args.dtype,
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
if quantization_config is not None:
|
||||
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
||||
model_kwargs["device_map"] = get_kbit_device_map()
|
||||
model_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
# Create model
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
|
||||
|
||||
if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
|
||||
|
||||
# Load the dataset
|
||||
if dataset_args.datasets and script_args.dataset_name:
|
||||
logger.warning(
|
||||
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
|
||||
"dataset and `dataset_name` will be ignored."
|
||||
)
|
||||
dataset = get_dataset(dataset_args)
|
||||
elif dataset_args.datasets and not script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_args)
|
||||
elif not dataset_args.datasets and script_args.dataset_name:
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
|
||||
|
||||
# Initialize the SFT trainer
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Log training complete
|
||||
trainer.accelerator.print("✅ Training completed.")
|
||||
|
||||
# Save and push to Hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None):
|
||||
from trl import DatasetMixtureConfig, ModelConfig, ScriptArguments, SFTConfig, TrlParser
|
||||
|
||||
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
|
||||
else:
|
||||
parser = TrlParser(dataclass_types, prog=prog)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
script_args, training_args, model_args, dataset_args = parser.parse_args_and_config(fail_with_unknown_args=False)
|
||||
main(script_args, training_args, model_args, dataset_args)
|
||||
@@ -0,0 +1,250 @@
|
||||
<h1 align="center" style="margin:0;">
|
||||
<a href="https://unsloth.ai/docs"><picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20WHITE%20LOGO.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20BLACK%20LOGO.png">
|
||||
<img alt="Unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20BLACK%20LOGO.png" height="60" style="max-width:100%;">
|
||||
</picture></a>
|
||||
</h1>
|
||||
<h3 align="center" style="margin: 0; margin-top: 0;">
|
||||
Run and train AI models with a unified local interface.
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
<a href="#-features">Features</a> •
|
||||
<a href="#-quickstart">Quickstart</a> •
|
||||
<a href="#-free-notebooks">Notebooks</a> •
|
||||
<a href="https://unsloth.ai/docs">Documentation</a> •
|
||||
<a href="https://www.reddit.com/r/unsloth/">Reddit</a>
|
||||
</p>
|
||||
<a href="https://unsloth.ai/docs/new/studio">
|
||||
<img alt="unsloth studio ui homepage" src="https://raw.githubusercontent.com/unslothai/unsloth/main/studio/frontend/public/studio%20github%20landscape%20colab%20display.png" style="max-width: 100%; margin-bottom: 0;"></a>
|
||||
|
||||
Unsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [embedding](https://unsloth.ai/docs/new/embedding-finetuning), [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) models on Windows, Linux and macOS.
|
||||
|
||||
## ⭐ Features
|
||||
Unsloth provides several key features for both inference and training:
|
||||
### Inference
|
||||
* **Search + download + run models** including GGUF, LoRA adapters, safetensors
|
||||
* **Export models**: [Save or export](https://unsloth.ai/docs/new/studio/export) models to GGUF, 16-bit safetensors and other formats.
|
||||
* **Tool calling**: Support for [self-healing tool calling](https://unsloth.ai/docs/new/studio/chat#auto-healing-tool-calling) and web search
|
||||
* **[Code execution](https://unsloth.ai/docs/new/studio/chat#code-execution)**: lets LLMs test code in Claude artifacts and sandbox environments
|
||||
* [Auto-tune inference parameters](https://unsloth.ai/docs/new/studio/chat#auto-parameter-tuning) and customize chat templates.
|
||||
* We work directly with teams behind [gpt-oss](https://docs.unsloth.ai/new/gpt-oss-how-to-run-and-fine-tune#unsloth-fixes-for-gpt-oss), [Qwen3](https://www.reddit.com/r/LocalLLaMA/comments/1kaodxu/qwen3_unsloth_dynamic_ggufs_128k_context_bug_fixes/), [Llama 4](https://github.com/ggml-org/llama.cpp/pull/12889), [Mistral](models/tutorials/devstral-how-to-run-and-fine-tune.md), [Gemma 1-3](https://news.ycombinator.com/item?id=39671146), and [Phi-4](https://unsloth.ai/blog/phi4), where we’ve fixed bugs that improve model accuracy.
|
||||
* Upload images, audio, PDFs, code, DOCX and more file types to chat with.
|
||||
### Training
|
||||
* Train and RL **500+ models** up to **2x faster** with up to **70% less VRAM**, with no accuracy loss.
|
||||
* Custom Triton and mathematical **kernels**. See some collabs we did with [PyTorch](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) and [Hugging Face](https://unsloth.ai/docs/new/faster-moe).
|
||||
* **Data Recipes**: [Auto-create datasets](https://unsloth.ai/docs/new/studio/data-recipe) from **PDF, CSV, DOCX** etc. Edit data in a visual-node workflow.
|
||||
* **[Reinforcement Learning](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide)** (RL): The most efficient [RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) library, using **80% less VRAM** for GRPO, [FP8](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) etc.
|
||||
* Supports full fine-tuning, RL, pretraining, 4-bit, 16-bit and, FP8 training.
|
||||
* **Observability**: Monitor training live, track loss and GPU usage and customize graphs.
|
||||
* [Multi-GPU](https://unsloth.ai/docs/basics/multi-gpu-training-with-unsloth) training is supported, with major improvements coming soon.
|
||||
|
||||
## ⚡ Quickstart
|
||||
Unsloth can be used in two ways: through **[Unsloth Studio](https://unsloth.ai/docs/new/studio/)**, the web UI, or through **Unsloth Core**, the code-based version. Each has different requirements.
|
||||
|
||||
### Unsloth Studio (web UI)
|
||||
Unsloth Studio (Beta) works on **Windows, Linux, WSL** and **macOS**.
|
||||
|
||||
* **CPU:** Supported for Chat and Data Recipes currently
|
||||
* **NVIDIA:** Training works on RTX 30/40/50, Blackwell, DGX Spark, Station and more
|
||||
* **macOS:** Currently supports chat and Data Recipes. **MLX training** is coming very soon
|
||||
* **AMD:** Chat + Data works. Train with [Unsloth Core](#unsloth-core-code-based). Studio support is out soon.
|
||||
* **Coming soon:** Training support for Apple MLX, AMD, and Intel.
|
||||
* **Multi-GPU:** Available now, with a major upgrade on the way
|
||||
|
||||
#### macOS, Linux, WSL:
|
||||
```bash
|
||||
curl -fsSL https://unsloth.ai/install.sh | sh
|
||||
```
|
||||
#### Windows:
|
||||
```powershell
|
||||
irm https://unsloth.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
#### Launch
|
||||
```bash
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
|
||||
#### Update
|
||||
To update, use the same install commands as above. Or run (does not work on Windows):
|
||||
```bash
|
||||
unsloth studio update
|
||||
```
|
||||
|
||||
#### Docker
|
||||
Use our [Docker image](https://hub.docker.com/r/unsloth/unsloth) ```unsloth/unsloth``` container. Run:
|
||||
```bash
|
||||
docker run -d -e JUPYTER_PASSWORD="mypassword" \
|
||||
-p 8888:8888 -p 8000:8000 -p 2222:22 \
|
||||
-v $(pwd)/work:/workspace/work \
|
||||
--gpus all \
|
||||
unsloth/unsloth
|
||||
```
|
||||
|
||||
#### Developer, Nightly, Uninstall
|
||||
To see developer, nightly and uninstallation etc. instructions, see [advanced installation](#-advanced-installation).
|
||||
|
||||
### Unsloth Core (code-based)
|
||||
#### Linux, WSL:
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv unsloth_env --python 3.13
|
||||
source unsloth_env/bin/activate
|
||||
uv pip install unsloth --torch-backend=auto
|
||||
```
|
||||
#### Windows:
|
||||
```powershell
|
||||
winget install -e --id Python.Python.3.13
|
||||
winget install --id=astral-sh.uv -e
|
||||
uv venv unsloth_env --python 3.13
|
||||
.\unsloth_env\Scripts\activate
|
||||
uv pip install unsloth --torch-backend=auto
|
||||
```
|
||||
For Windows, `pip install unsloth` works only if you have PyTorch installed. Read our [Windows Guide](https://unsloth.ai/docs/get-started/install/windows-installation).
|
||||
You can use the same Docker image as Unsloth Studio.
|
||||
|
||||
#### AMD, Intel:
|
||||
For RTX 50x, B200, 6000 GPUs: `uv pip install unsloth --torch-backend=auto`. Read our guides for: [Blackwell](https://unsloth.ai/docs/blog/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth) and [DGX Spark](https://unsloth.ai/docs/blog/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth). <br>
|
||||
To install Unsloth on **AMD** and **Intel** GPUs, follow our [AMD Guide](https://unsloth.ai/docs/get-started/install/amd) and [Intel Guide](https://unsloth.ai/docs/get-started/install/intel).
|
||||
|
||||
## 📒 Free Notebooks
|
||||
|
||||
Train for free with our notebooks. You can use our new [free Unsloth Studio notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb) to run and train models for free in a web UI.
|
||||
Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Add dataset, run, then deploy your trained model.
|
||||
|
||||
| Model | Free Notebooks | Performance | Memory use |
|
||||
|-----------|---------|--------|----------|
|
||||
| **Gemma 4 (E2B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma4_(E2B)-Vision.ipynb) | 1.5x faster | 50% less |
|
||||
| **Qwen3.5 (4B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision.ipynb) | 1.5x faster | 60% less |
|
||||
| **gpt-oss (20B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-Fine-tuning.ipynb) | 2x faster | 70% less |
|
||||
| **Qwen3.5 GSPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision_GRPO.ipynb) | 2x faster | 70% less |
|
||||
| **gpt-oss (20B): GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) | 2x faster | 80% less |
|
||||
| **Qwen3: Advanced GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb) | 2x faster | 70% less |
|
||||
| **embeddinggemma (300M)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/EmbeddingGemma_(300M).ipynb) | 2x faster | 20% less |
|
||||
| **Mistral Ministral 3 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Ministral_3_VL_(3B)_Vision.ipynb) | 1.5x faster | 60% less |
|
||||
| **Llama 3.1 (8B) Alpaca** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less |
|
||||
| **Llama 3.2 Conversational** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 70% less |
|
||||
| **Orpheus-TTS (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Orpheus_(3B)-TTS.ipynb) | 1.5x faster | 50% less |
|
||||
|
||||
- See all our notebooks for: [Kaggle](https://github.com/unslothai/notebooks?tab=readme-ov-file#-kaggle-notebooks), [GRPO](https://unsloth.ai/docs/get-started/unsloth-notebooks#grpo-reasoning-rl-notebooks), [TTS](https://unsloth.ai/docs/get-started/unsloth-notebooks#text-to-speech-tts-notebooks), [embedding](https://unsloth.ai/docs/new/embedding-finetuning) & [Vision](https://unsloth.ai/docs/get-started/unsloth-notebooks#vision-multimodal-notebooks)
|
||||
- See [all our models](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [all our notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks)
|
||||
- See detailed documentation for Unsloth [here](https://unsloth.ai/docs)
|
||||
|
||||
## 🦥 Unsloth News
|
||||
- **Gemma 4**: Run and train Google’s new models directly in Unsloth Studio! [Blog](https://unsloth.ai/docs/models/gemma-4)
|
||||
- **Introducing Unsloth Studio**: our new web UI for running and training LLMs. [Blog](https://unsloth.ai/docs/new/studio)
|
||||
- **Qwen3.5** - 0.8B, 2B, 4B, 9B, 27B, 35-A3B, 112B-A10B are now supported. [Guide + notebooks](https://unsloth.ai/docs/models/qwen3.5/fine-tune)
|
||||
- Train **MoE LLMs 12x faster** with 35% less VRAM - DeepSeek, GLM, Qwen and gpt-oss. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
- **Embedding models**: Unsloth now supports ~1.8-3.3x faster embedding fine-tuning. [Blog](https://unsloth.ai/docs/new/embedding-finetuning) • [Notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks#embedding-models)
|
||||
- New **7x longer context RL** vs. all other setups, via our new batching algorithms. [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
- New RoPE & MLP **Triton Kernels** & **Padding Free + Packing**: 3x faster training & 30% less VRAM. [Blog](https://unsloth.ai/docs/new/3x-faster-training-packing)
|
||||
- **500K Context**: Training a 20B model with >500K context is now possible on an 80GB GPU. [Blog](https://unsloth.ai/docs/blog/500k-context-length-fine-tuning)
|
||||
- **FP8 & Vision RL**: You can now do FP8 & VLM GRPO on consumer GPUs. [FP8 Blog](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/vision-reinforcement-learning-vlm-rl)
|
||||
- **gpt-oss** by OpenAI: Read our [RL blog](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/gpt-oss-reinforcement-learning), [Flex Attention](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training) blog and [Guide](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune).
|
||||
|
||||
## 📥 Advanced Installation
|
||||
The below advanced instructions are for Unsloth Studio. For Unsloth Core advanced installation, [view our docs](https://unsloth.ai/docs/get-started/install/pip-install#advanced-pip-installation).
|
||||
#### Developer installs: macOS, Linux, WSL:
|
||||
```bash
|
||||
git clone https://github.com/unslothai/unsloth
|
||||
cd unsloth
|
||||
./install.sh --local
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
Then to update :
|
||||
```bash
|
||||
unsloth studio update
|
||||
```
|
||||
|
||||
#### Developer installs: Windows PowerShell:
|
||||
```powershell
|
||||
git clone https://github.com/unslothai/unsloth.git
|
||||
cd unsloth
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
.\install.ps1 --local
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
Then to update :
|
||||
```bash
|
||||
unsloth studio update
|
||||
```
|
||||
|
||||
#### Nightly: MacOS, Linux, WSL:
|
||||
```bash
|
||||
git clone https://github.com/unslothai/unsloth
|
||||
cd unsloth
|
||||
git checkout nightly
|
||||
./install.sh --local
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
Then to launch every time:
|
||||
```bash
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
|
||||
#### Nightly: Windows:
|
||||
Run in Windows Powershell:
|
||||
```bash
|
||||
git clone https://github.com/unslothai/unsloth.git
|
||||
cd unsloth
|
||||
git checkout nightly
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
.\install.ps1 --local
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
Then to launch every time:
|
||||
```bash
|
||||
unsloth studio -H 0.0.0.0 -p 8888
|
||||
```
|
||||
|
||||
#### Uninstall
|
||||
You can uninstall Unsloth Studio by deleting its install folder usually located under `$HOME/.unsloth/studio` on Mac/Linux/WSL and `%USERPROFILE%\.unsloth\studio` on Windows. Using the `rm -rf` commands will **delete everything**, including your history, cache:
|
||||
|
||||
* **MacOS, WSL, Linux:** `rm -rf ~/.unsloth/studio`
|
||||
* **Windows (PowerShell):** `Remove-Item -Recurse -Force "$HOME\.unsloth\studio"`
|
||||
|
||||
For more info, [see our docs](https://unsloth.ai/docs/new/studio/install#uninstall).
|
||||
|
||||
#### Deleting model files
|
||||
|
||||
You can delete old model files either from the bin icon in model search or by removing the relevant cached model folder from the default Hugging Face cache directory. By default, HF uses:
|
||||
|
||||
* **MacOS, Linux, WSL:** `~/.cache/huggingface/hub/`
|
||||
* **Windows:** `%USERPROFILE%\.cache\huggingface\hub\`
|
||||
|
||||
## 💚 Community and Links
|
||||
| Type | Links |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ |
|
||||
| <img width="16" src="https://cdn.prod.website-files.com/6257adef93867e50d84d30e2/66e3d80db9971f10a9757c99_Symbol.svg" /> **Discord** | [Join Discord server](https://discord.com/invite/unsloth) |
|
||||
| <img width="15" src="https://redditinc.com/hs-fs/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" /> **r/unsloth Reddit** | [Join Reddit community](https://reddit.com/r/unsloth) |
|
||||
| 📚 **Documentation & Wiki** | [Read Our Docs](https://unsloth.ai/docs) |
|
||||
| <img width="13" src="https://upload.wikimedia.org/wikipedia/commons/0/09/X_(formerly_Twitter)_logo_late_2025.svg" /> **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai) |
|
||||
| 🔮 **Our Models** | [Unsloth Catalog](https://unsloth.ai/docs/get-started/unsloth-model-catalog) |
|
||||
| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog) |
|
||||
|
||||
### Citation
|
||||
|
||||
You can cite the Unsloth repo as follows:
|
||||
```bibtex
|
||||
@software{unsloth,
|
||||
author = {Daniel Han, Michael Han and Unsloth team},
|
||||
title = {Unsloth},
|
||||
url = {https://github.com/unslothai/unsloth},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
If you trained a model with 🦥Unsloth, you can use this cool sticker! <img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png" width="200" align="center" />
|
||||
|
||||
### License
|
||||
Unsloth uses a dual-licensing model of Apache 2.0 and AGPL-3.0. The core Unsloth package remains licensed under **[Apache 2.0](https://github.com/unslothai/unsloth?tab=Apache-2.0-1-ov-file)**, while certain optional components, such as the Unsloth Studio UI are licensed under the open-source license **[AGPL-3.0](https://github.com/unslothai/unsloth?tab=AGPL-3.0-2-ov-file)**.
|
||||
|
||||
This structure helps support ongoing Unsloth development while keeping the project open source and enabling the broader ecosystem to continue growing.
|
||||
|
||||
### Thank You to
|
||||
- The [llama.cpp library](https://github.com/ggml-org/llama.cpp) that lets users run and save models with Unsloth
|
||||
- The Hugging Face team and their libraries: [transformers](https://github.com/huggingface/transformers) and [TRL](https://github.com/huggingface/trl)
|
||||
- The Pytorch and [Torch AO](https://github.com/unslothai/unsloth/pull/3391) team for their contributions
|
||||
- NVIDIA for their [NeMo DataDesigner](https://github.com/NVIDIA-NeMo/DataDesigner) library and their contributions
|
||||
- And of course for every single person who has contributed or has used Unsloth!
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
+11898
File diff suppressed because it is too large
Load Diff
+10738
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,512 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a Google Colab A100 instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
#
|
||||
# `FastModel` supports loading nearly any model now! This includes Vision and Text models!
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastModel
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-26B-A4B-it",
|
||||
dtype = None, # None for auto detection
|
||||
max_seq_length = 8192, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "YOUR_HF_TOKEN", # HF Token for gated models
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can process Text, Vision and Audio!
|
||||
#
|
||||
# Let's first experience how Gemma 4 can handle multimodal inputs. We use Gemma 4's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
from transformers import TextStreamer
|
||||
# Helper function for inference
|
||||
def do_gemma_4_inference(messages, max_new_tokens = 128):
|
||||
_ = model.generate(
|
||||
**tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
return_tensors = "pt",
|
||||
).to("cuda"),
|
||||
max_new_tokens = max_new_tokens,
|
||||
use_cache = True,
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can see images!
|
||||
#
|
||||
# <img src="https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg" alt="Alt text" height="256">
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "image", "image" : sloth_link },
|
||||
{ "type": "text", "text" : "Which films does this animal feature in?" }
|
||||
]
|
||||
}]
|
||||
# You might have to wait 1 minute for Unsloth's auto compiler
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# Let's make a poem about sloths!
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{ "type" : "text",
|
||||
"text" : "Write a poem about sloths." }]
|
||||
}]
|
||||
do_gemma_4_inference(messages)
|
||||
|
||||
|
||||
# # Let's finetune Gemma 4!
|
||||
#
|
||||
# You can finetune the vision and text parts for now through selection - the audio part can also be finetuned - we're working to make it selectable as well!
|
||||
|
||||
# We now add LoRA adapters so we only need to update a small amount of parameters!
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # Should leave on always!
|
||||
|
||||
r = 8, # Larger = higher accuracy, but might overfit
|
||||
lora_alpha = 8, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We now use the `Gemma-4` format for conversation style finetunes. We use [Maxime Labonne's FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) dataset in ShareGPT style. Gemma-4 renders multi turn conversations like below:
|
||||
#
|
||||
# ```
|
||||
# <bos><|turn>user
|
||||
# Hello<turn|>
|
||||
# <|turn>model
|
||||
# Hey there!<turn|>
|
||||
# ```
|
||||
# We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3, gemma-4` and more.
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4-thinking",
|
||||
)
|
||||
|
||||
|
||||
# We get the first 3000 rows of the dataset
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("mlabonne/FineTome-100k", split = "train[:3000]")
|
||||
|
||||
|
||||
# We now use `standardize_data_formats` to try converting datasets to the correct format for finetuning purposes!
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import standardize_data_formats
|
||||
dataset = standardize_data_formats(dataset)
|
||||
|
||||
|
||||
# Let's see how row 100 looks like!
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
dataset[100]
|
||||
|
||||
|
||||
# We now have to apply the chat template for `Gemma-3` onto the conversations, and save it to `text`. We remove the `<bos>` token using removeprefix(`'<bos>'`) since we're finetuning. The Processor will add this token before training and the model expects only one.
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversations"]
|
||||
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
|
||||
return { "text" : texts, }
|
||||
|
||||
dataset = dataset.map(formatting_prompts_func, batched = True)
|
||||
|
||||
|
||||
# Let's see how the chat template did! Notice there is no `<bos>` token as the processor tokenizer will be adding one.
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
dataset[100]["text"]
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`.
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset,
|
||||
eval_dataset = None, # Can set up evaluation!
|
||||
args = SFTConfig(
|
||||
dataset_text_field = "text",
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4, # Use GA to mimic batch size!
|
||||
warmup_steps = 5,
|
||||
# num_train_epochs = 1, # Set this for 1 full training run.
|
||||
max_steps = 60,
|
||||
learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
|
||||
logging_steps = 1,
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
report_to = "none", # Use TrackIO/WandB etc
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# We also use Unsloth's `train_on_completions` method to only train on the assistant outputs and ignore the loss on the user's inputs. This helps increase accuracy of finetunes!
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import train_on_responses_only
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|turn>user\n",
|
||||
response_part = "<|turn>model\n",
|
||||
)
|
||||
|
||||
|
||||
# Let's verify masking the instruction part is done! Let's print the 100th row again. Notice how the sample only has a single `<bos>` as expected!
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
tokenizer.decode(trainer.train_dataset[100]["input_ids"])
|
||||
|
||||
|
||||
# Now let's print the masked out example - you should see only the answer is present:
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")
|
||||
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# # Let's train the model!
|
||||
#
|
||||
# To resume a training run, set `trainer.train(resume_from_checkpoint = True)`
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model via Unsloth native inference! According to the `Gemma-3` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4-thinking",
|
||||
)
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type" : "text",
|
||||
"text" : "Continue the sequence: 1, 1, 2, 3, 5, 8,",
|
||||
}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
use_cache = True,
|
||||
# Recommended Gemma-3 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
)
|
||||
tokenizer.batch_decode(outputs)
|
||||
|
||||
|
||||
# You can also use a `TextStreamer` for continuous inference - so you can see the generation token by token, instead of waiting the whole time!
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "Why is the sky blue?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
use_cache = True,
|
||||
# Recommended Gemma-3 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[23]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# tokenizer.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[24]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastModel
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "What is Gemma-4?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 128, # Increase for longer outputs!
|
||||
# Recommended Gemma-3 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly for deployment! We save it in the folder `gemma-4-finetune`. Set `if False` to `if True` to let it run!
|
||||
|
||||
# In[25]:
|
||||
|
||||
|
||||
if False: # Change to True to save finetune!
|
||||
model.save_pretrained_merged("gemma-4-finetune", tokenizer)
|
||||
|
||||
|
||||
# If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[26]:
|
||||
|
||||
|
||||
if False: # Change to True to upload finetune
|
||||
model.push_to_hub_merged(
|
||||
"HF_ACCOUNT/gemma-4-finetune", tokenizer,
|
||||
token = "YOUR_HF_TOKEN"
|
||||
)
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
|
||||
|
||||
# In[27]:
|
||||
|
||||
|
||||
if False: # Change to True to save to GGUF
|
||||
model.save_pretrained_gguf(
|
||||
"gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
|
||||
)
|
||||
|
||||
|
||||
# Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[28]:
|
||||
|
||||
|
||||
if False: # Change to True to upload GGUF
|
||||
model.push_to_hub_gguf(
|
||||
"HF_ACCOUNT/gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma-4-finetune.gguf` file or `gemma-4-finetune-Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a Google Colab A100 instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel # FastLanguageModel for LLMs
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
|
||||
)
|
||||
|
||||
|
||||
# We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.
|
||||
#
|
||||
# **[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = True, # False if not finetuning vision layers
|
||||
finetune_language_layers = True, # False if not finetuning language layers
|
||||
finetune_attention_modules = True, # False if not finetuning attention layers
|
||||
finetune_mlp_modules = True, # False if not finetuning MLP layers
|
||||
|
||||
r = 32, # The larger, the higher the accuracy, but might overfit
|
||||
lora_alpha = 32, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
target_modules = "all-linear", # Optional now! Can specify a list if needed
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We'll use a sampled dataset of handwritten math formulas. The objective is to convert these images into a computer-readable format—specifically LaTeX—so they can be rendered. This is particularly useful for complex expressions.
|
||||
#
|
||||
# You can access the dataset [here](https://huggingface.co/datasets/unsloth/LaTeX_OCR). The full dataset is [here](https://huggingface.co/datasets/linxy/LaTeX_OCR).
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")
|
||||
|
||||
|
||||
# Let's take an overview of the dataset. We'll examine the second image and its corresponding caption.
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
dataset
|
||||
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
dataset[2]["image"]
|
||||
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
dataset[2]["text"]
|
||||
|
||||
|
||||
# We can also render LaTeX directly in the browser!
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
from IPython.display import display, Math, Latex
|
||||
|
||||
latex = dataset[3]["text"]
|
||||
display(Math(latex))
|
||||
|
||||
|
||||
# To format the dataset, all vision fine-tuning tasks should follow this format:
|
||||
#
|
||||
# ```python
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# ]
|
||||
# ```
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
def convert_to_conversation(sample):
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": instruction},
|
||||
{"type": "image", "image": sample["image"]},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
|
||||
]
|
||||
return {"messages": conversation}
|
||||
pass
|
||||
|
||||
|
||||
# Let's convert the dataset into the "correct" format for finetuning:
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
|
||||
|
||||
|
||||
# The first example is now structured like below:
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
converted_dataset[0]
|
||||
|
||||
|
||||
# Lets take the Gemma 4 instruction chat template and use it in our base model
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
from unsloth import get_chat_template
|
||||
|
||||
processor = get_chat_template(
|
||||
processor,
|
||||
"gemma-4-thinking"
|
||||
)
|
||||
|
||||
|
||||
# Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before.
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
image = dataset[2]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# You can see it's absolutely terrible! It doesn't follow instructions at all
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support `DPOTrainer` and `GRPOTrainer` for reinforcement learning!
|
||||
#
|
||||
# We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup.
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
from unsloth.trainer import UnslothVisionDataCollator
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = converted_dataset,
|
||||
processing_class = processor.tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, processor),
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4,
|
||||
max_grad_norm = 0.3,
|
||||
warmup_ratio = 0.03,
|
||||
max_steps = 60,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
learning_rate = 2e-4,
|
||||
logging_steps = 1,
|
||||
save_strategy = "steps",
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "cosine",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none", # For Weights and Biases or others
|
||||
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
max_length = 2048,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model! You can modify the instruction and input—just leave the output blank.
|
||||
#
|
||||
# We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`.
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
image = dataset[10]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
processor.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# processor.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
load_in_4bit = True, # Set to False for 16bit LoRA
|
||||
)
|
||||
|
||||
sample = dataset[1]
|
||||
image = sample["image"].convert("RGB")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sample["text"],
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor.tokenizer, skip_prompt = True)
|
||||
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# Select ONLY 1 to save! (Both not needed!)
|
||||
|
||||
# Save locally to 16bit
|
||||
if False: model.save_pretrained_merged("unsloth_finetune", processor,)
|
||||
|
||||
# To export and save to your Hugging Face account
|
||||
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", processor, token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,513 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a Google Colab A100 instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
#
|
||||
# `FastModel` supports loading nearly any model now! This includes Vision and Text models!
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastModel
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-31B-it",
|
||||
dtype = None, # None for auto detection
|
||||
max_seq_length = 8192, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "YOUR_HF_TOKEN", # HF Token for gated models
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can process Text, Vision and Audio!
|
||||
#
|
||||
# Let's first experience how Gemma 4 can handle multimodal inputs. We use Gemma 4's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
from transformers import TextStreamer
|
||||
# Helper function for inference
|
||||
def do_gemma_4_inference(messages, max_new_tokens = 128):
|
||||
_ = model.generate(
|
||||
**tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
return_tensors = "pt",
|
||||
).to("cuda"),
|
||||
max_new_tokens = max_new_tokens,
|
||||
use_cache = True,
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can see images!
|
||||
#
|
||||
# <img src="https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg" alt="Alt text" height="256">
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "image", "image" : sloth_link },
|
||||
{ "type": "text", "text" : "Which films does this animal feature in?" }
|
||||
]
|
||||
}]
|
||||
# You might have to wait 1 minute for Unsloth's auto compiler
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# Let's make a poem about sloths!
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{ "type" : "text",
|
||||
"text" : "Write a poem about sloths." }]
|
||||
}]
|
||||
do_gemma_4_inference(messages)
|
||||
|
||||
|
||||
# # Let's finetune Gemma 4!
|
||||
#
|
||||
# You can finetune the vision and text parts for now through selection - the audio part can also be finetuned - we're working to make it selectable as well!
|
||||
|
||||
# We now add LoRA adapters so we only need to update a small amount of parameters!
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # Should leave on always!
|
||||
|
||||
r = 8, # Larger = higher accuracy, but might overfit
|
||||
lora_alpha = 8, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We now use the `Gemma-4` format for conversation style finetunes. We use [Maxime Labonne's FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) dataset in ShareGPT style. Gemma-4 renders multi turn conversations like below:
|
||||
#
|
||||
# ```
|
||||
# <bos><|turn>user
|
||||
# Hello<turn|>
|
||||
# <|turn>model
|
||||
# Hey there!<turn|>
|
||||
# ```
|
||||
# We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3, gemma-4` and more.
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4-thinking",
|
||||
)
|
||||
|
||||
|
||||
# We get the first 3000 rows of the dataset
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("mlabonne/FineTome-100k", split = "train[:3000]")
|
||||
|
||||
|
||||
# We now use `standardize_data_formats` to try converting datasets to the correct format for finetuning purposes!
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import standardize_data_formats
|
||||
dataset = standardize_data_formats(dataset)
|
||||
|
||||
|
||||
# Let's see how row 100 looks like!
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
dataset[100]
|
||||
|
||||
|
||||
# We now have to apply the chat template for `Gemma-4` onto the conversations, and save it to `text`. We remove the `<bos>` token using removeprefix(`'<bos>'`) since we're finetuning. The Processor will add this token before training and the model expects only one.
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversations"]
|
||||
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
|
||||
return { "text" : texts, }
|
||||
|
||||
dataset = dataset.map(formatting_prompts_func, batched = True)
|
||||
|
||||
|
||||
# Let's see how the chat template did! Notice there is no `<bos>` token as the processor tokenizer will be adding one.
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
dataset[100]["text"]
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`.
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset,
|
||||
eval_dataset = None, # Can set up evaluation!
|
||||
args = SFTConfig(
|
||||
dataset_text_field = "text",
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4, # Use GA to mimic batch size!
|
||||
warmup_steps = 5,
|
||||
# num_train_epochs = 1, # Set this for 1 full training run.
|
||||
max_steps = 60,
|
||||
learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
|
||||
logging_steps = 1,
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
report_to = "none", # Use TrackIO/WandB etc
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# We also use Unsloth's `train_on_completions` method to only train on the assistant outputs and ignore the loss on the user's inputs. This helps increase accuracy of finetunes!
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import train_on_responses_only
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|turn>user\n",
|
||||
response_part = "<|turn>model\n",
|
||||
)
|
||||
|
||||
|
||||
# Let's verify masking the instruction part is done! Let's print the 100th row again. Notice how the sample only has a single `<bos>` as expected!
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
tokenizer.decode(trainer.train_dataset[100]["input_ids"])
|
||||
|
||||
|
||||
# Now let's print the masked out example - you should see only the answer is present:
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")
|
||||
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# # Let's train the model!
|
||||
#
|
||||
# To resume a training run, set `trainer.train(resume_from_checkpoint = True)`
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model via Unsloth native inference! According to the `Gemma-4` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4-thinking",
|
||||
)
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type" : "text",
|
||||
"text" : "Continue the sequence: 1, 1, 2, 3, 5, 8,",
|
||||
}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
use_cache = True,
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
)
|
||||
tokenizer.batch_decode(outputs)
|
||||
|
||||
|
||||
# You can also use a `TextStreamer` for continuous inference - so you can see the generation token by token, instead of waiting the whole time!
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "Why is the sky blue?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
use_cache = True,
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[23]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# tokenizer.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[24]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastModel
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "What is Gemma-4?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 128, # Increase for longer outputs!
|
||||
use_cache = True,
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly for deployment! We save it in the folder `gemma-4-finetune`. Set `if False` to `if True` to let it run!
|
||||
|
||||
# In[25]:
|
||||
|
||||
|
||||
if False: # Change to True to save finetune!
|
||||
model.save_pretrained_merged("gemma-4-finetune", tokenizer)
|
||||
|
||||
|
||||
# If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[26]:
|
||||
|
||||
|
||||
if False: # Change to True to upload finetune
|
||||
model.push_to_hub_merged(
|
||||
"HF_ACCOUNT/gemma-4-finetune", tokenizer,
|
||||
token = "YOUR_HF_TOKEN"
|
||||
)
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
|
||||
|
||||
# In[27]:
|
||||
|
||||
|
||||
if False: # Change to True to save to GGUF
|
||||
model.save_pretrained_gguf(
|
||||
"gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
|
||||
)
|
||||
|
||||
|
||||
# Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[28]:
|
||||
|
||||
|
||||
if False: # Change to True to upload GGUF
|
||||
model.push_to_hub_gguf(
|
||||
"HF_ACCOUNT/gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma-4-finetune.gguf` file or `gemma-4-finetune-Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a Google Colab A100 instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel # FastLanguageModel for LLMs
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
"unsloth/gemma-4-31B-it",
|
||||
load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
|
||||
)
|
||||
|
||||
|
||||
# We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.
|
||||
#
|
||||
# **[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = True, # False if not finetuning vision layers
|
||||
finetune_language_layers = True, # False if not finetuning language layers
|
||||
finetune_attention_modules = True, # False if not finetuning attention layers
|
||||
finetune_mlp_modules = True, # False if not finetuning MLP layers
|
||||
|
||||
r = 32, # The larger, the higher the accuracy, but might overfit
|
||||
lora_alpha = 32, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
target_modules = "all-linear", # Optional now! Can specify a list if needed
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We'll use a sampled dataset of handwritten math formulas. The objective is to convert these images into a computer-readable format—specifically LaTeX—so they can be rendered. This is particularly useful for complex expressions.
|
||||
#
|
||||
# You can access the dataset [here](https://huggingface.co/datasets/unsloth/LaTeX_OCR). The full dataset is [here](https://huggingface.co/datasets/linxy/LaTeX_OCR).
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")
|
||||
|
||||
|
||||
# Let's take an overview of the dataset. We'll examine the second image and its corresponding caption.
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
dataset
|
||||
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
dataset[2]["image"]
|
||||
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
dataset[2]["text"]
|
||||
|
||||
|
||||
# We can also render LaTeX directly in the browser!
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
from IPython.display import display, Math, Latex
|
||||
|
||||
latex = dataset[3]["text"]
|
||||
display(Math(latex))
|
||||
|
||||
|
||||
# To format the dataset, all vision fine-tuning tasks should follow this format:
|
||||
#
|
||||
# ```python
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# ]
|
||||
# ```
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
def convert_to_conversation(sample):
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": instruction},
|
||||
{"type": "image", "image": sample["image"]},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
|
||||
]
|
||||
return {"messages": conversation}
|
||||
pass
|
||||
|
||||
|
||||
# Let's convert the dataset into the "correct" format for finetuning:
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
|
||||
|
||||
|
||||
# The first example is now structured like below:
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
converted_dataset[0]
|
||||
|
||||
|
||||
# Lets take the Gemma 4 instruction chat template and use it in our base model
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
from unsloth import get_chat_template
|
||||
|
||||
processor = get_chat_template(
|
||||
processor,
|
||||
"gemma-4-thinking"
|
||||
)
|
||||
|
||||
|
||||
# Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before.
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
image = dataset[2]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# You can see it's absolutely terrible! It doesn't follow instructions at all
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support `DPOTrainer` and `GRPOTrainer` for reinforcement learning!
|
||||
#
|
||||
# We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup.
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
from unsloth.trainer import UnslothVisionDataCollator
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = converted_dataset,
|
||||
processing_class = processor.tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, processor),
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4,
|
||||
max_grad_norm = 0.3,
|
||||
warmup_ratio = 0.03,
|
||||
max_steps = 60,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
learning_rate = 2e-4,
|
||||
logging_steps = 1,
|
||||
save_strategy = "steps",
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "cosine",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none", # For Weights and Biases or others
|
||||
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
max_length = 2048,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model! You can modify the instruction and input—just leave the output blank.
|
||||
#
|
||||
# We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`.
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
image = dataset[10]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
processor.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# processor.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
load_in_4bit = True, # Set to False for 16bit LoRA
|
||||
)
|
||||
|
||||
sample = dataset[1]
|
||||
image = sample["image"].convert("RGB")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sample["text"],
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor.tokenizer, skip_prompt = True)
|
||||
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# Select ONLY 1 to save! (Both not needed!)
|
||||
|
||||
# Save locally to 16bit
|
||||
if False: model.save_pretrained_merged("unsloth_finetune", processor,)
|
||||
|
||||
# To export and save to your Hugging Face account
|
||||
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", processor, token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,478 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
#
|
||||
# `FastModel` supports loading nearly any model now! This includes Vision and Text models!
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastModel
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
fourbit_models = [
|
||||
# Gemma 4 models
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, processor = FastModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E2B-it",
|
||||
dtype = None, # None for auto detection
|
||||
max_seq_length = 8192, # Choose any for long context!
|
||||
load_in_4bit = False, # 4 bit quantization to reduce memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "YOUR_HF_TOKEN", # HF Token for gated models
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can process Text, Vision and Audio!
|
||||
#
|
||||
# Let's first experience how Gemma 4 can handle multimodal inputs. We use Gemma 4's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64` but for this example we use `do_sample=False` for ASR.
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
from transformers import TextStreamer
|
||||
# Helper function for inference
|
||||
def do_gemma_4_inference(messages, max_new_tokens = 128):
|
||||
_ = model.generate(
|
||||
**processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
return_tensors = "pt",
|
||||
).to("cuda"),
|
||||
max_new_tokens = max_new_tokens,
|
||||
do_sample = False,
|
||||
streamer = TextStreamer(processor, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# <h3>Let's Evaluate Gemma 4 Baseline Performance on German Transcription</h2>
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
from datasets import load_dataset,Audio,concatenate_datasets
|
||||
|
||||
dataset = load_dataset("kadirnar/Emilia-DE-B000000", split = "train")
|
||||
|
||||
# Select a single audio sample to reserve for testing.
|
||||
# This index is chosen from the full dataset before we create the smaller training split.
|
||||
test_audio = dataset[7546]
|
||||
|
||||
dataset = dataset.select(range(3000))
|
||||
|
||||
dataset = dataset.cast_column("audio", Audio(sampling_rate = 16000))
|
||||
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
from IPython.display import Audio, display
|
||||
print(test_audio['text'])
|
||||
Audio(test_audio['audio']['array'],rate = test_audio['audio']['sampling_rate'])
|
||||
|
||||
|
||||
# And the translation of the audio from German to English is:
|
||||
#
|
||||
# > I—I hold myself directly accountable. That much is, of course, clear: namely, that there are political interests involved in trade—in the exchange of goods—and that political influences are at play. The question is: that should not be the alternative.
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": test_audio['audio']['array']},
|
||||
{"type": "text", "text": "Please transcribe this audio."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# <h3>Baseline Model Performance: 32.43% Word Error Rate (WER) for this sample !</h3>
|
||||
|
||||
# # Let's finetune Gemma 4!
|
||||
#
|
||||
# You can finetune the vision and text and audio parts
|
||||
|
||||
# We now add LoRA adapters so we only need to update a small amount of parameters!
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # False if not finetuning vision layers
|
||||
finetune_language_layers = True, # False if not finetuning language layers
|
||||
finetune_attention_modules = True, # False if not finetuning attention layers
|
||||
finetune_mlp_modules = True, # False if not finetuning MLP layers
|
||||
|
||||
r = 8, # The larger, the higher the accuracy, but might overfit
|
||||
lora_alpha = 16, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
target_modules = [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
|
||||
# Audio layers
|
||||
"post", "linear_start", "linear_end",
|
||||
"embedding_projection",
|
||||
"ffw_layer_1", "ffw_layer_2",
|
||||
"output_proj",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We adapt the `kadirnar/Emilia-DE-B000000` dataset for our German ASR task using Gemma 4 multi-modal chat format. Each audio-text pair is structured into a conversation with `system`, `user`, and `assistant` roles. The processor then converts this into the final training format:
|
||||
#
|
||||
# ```
|
||||
# <bos><|turn>system
|
||||
# You are an assistant that transcribes speech accurately.<turn|>
|
||||
# <|turn>user
|
||||
# <|audio|>Please transcribe this audio.<turn|>
|
||||
# <|turn>model
|
||||
# Ich, ich rechne direkt mich an.<turn|>
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
def format_intersection_data(samples: dict) -> dict[str, list]:
|
||||
"""Format intersection dataset to match expected message format"""
|
||||
formatted_samples = {"messages": []}
|
||||
for idx in range(len(samples["audio"])):
|
||||
audio = samples["audio"][idx]["array"]
|
||||
label = str(samples["text"][idx])
|
||||
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": audio},
|
||||
{"type": "text", "text": "Please transcribe this audio."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content":[{"type": "text", "text": label}]
|
||||
}
|
||||
]
|
||||
formatted_samples["messages"].append(message)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
dataset = dataset.map(format_intersection_data, batched = True, batch_size = 4, num_proc = 4)
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`.
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
# Use UnslothVisionDataCollator which handles audio token alignment correctly
|
||||
from unsloth.trainer import UnslothVisionDataCollator
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = dataset,
|
||||
processing_class = processor.tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, processor),
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 8,
|
||||
gradient_accumulation_steps = 1,
|
||||
warmup_ratio = 0.03,
|
||||
# num_train_epochs = 1, # Use for full training runs
|
||||
max_steps = 60,
|
||||
learning_rate = 5e-5,
|
||||
logging_steps = 1,
|
||||
save_strategy = "steps",
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "cosine",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
remove_unused_columns = False,
|
||||
|
||||
# The below are a must for audio finetuning:
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
max_length = 8192,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# # Let's train the model!
|
||||
#
|
||||
# To resume a training run, set `trainer.train(resume_from_checkpoint = True)`
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model via Unsloth native inference! According to the `Gemma-4` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64` but for this example we use `do_sample=False` for ASR.
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": test_audio['audio']['array']},
|
||||
{"type": "text", "text": "Please transcribe this audio."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
processor.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# processor.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastModel
|
||||
model, processor = FastModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "What is Gemma-4?",}]
|
||||
}]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 128, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(processor, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly for deployment! We save it in the folder `gemma-4-finetune`. Set `if False` to `if True` to let it run!
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
if False: # Change to True to save finetune!
|
||||
model.save_pretrained_merged("gemma-4", processor)
|
||||
|
||||
|
||||
# If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
if False: # Change to True to upload finetune
|
||||
model.push_to_hub_merged(
|
||||
"HF_ACCOUNT/gemma-4-finetune", processor,
|
||||
token = "YOUR_HF_TOKEN"
|
||||
)
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
if False: # Change to True to save to GGUF
|
||||
model.save_pretrained_gguf(
|
||||
"gemma_4_finetune",
|
||||
processor,
|
||||
quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
|
||||
)
|
||||
|
||||
|
||||
# Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
if False: # Change to True to upload GGUF
|
||||
model.push_to_hub_gguf(
|
||||
"HF_ACCOUNT/gemma_4_finetune",
|
||||
processor,
|
||||
quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma-4-finetune.gguf` file or `gemma-4-finetune-Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,556 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
#
|
||||
# `FastModel` supports loading nearly any model now! This includes Vision and Text models!
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastModel
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E2B-it",
|
||||
dtype = None, # None for auto detection
|
||||
max_seq_length = 1024, # Choose any for long context!
|
||||
load_in_4bit = False, # 4 bit quantization to reduce memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "YOUR_HF_TOKEN", # HF Token for gated models
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can process Text, Vision and Audio!
|
||||
#
|
||||
# Let's first experience how Gemma 4 can handle multimodal inputs. We use Gemma 4's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
from transformers import TextStreamer
|
||||
# Helper function for inference
|
||||
def do_gemma_4_inference(messages, max_new_tokens = 128):
|
||||
_ = model.generate(
|
||||
**tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
return_tensors = "pt",
|
||||
).to("cuda"),
|
||||
max_new_tokens = max_new_tokens,
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True)
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can see images!
|
||||
#
|
||||
# <img src="https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg" alt="Alt text" height="256">
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "image", "image" : sloth_link },
|
||||
{ "type": "text", "text" : "Which films does this animal feature in?" }
|
||||
]
|
||||
}]
|
||||
# You might have to wait 1 minute for Unsloth's auto compiler
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# Let's make a poem about sloths!
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{ "type" : "text",
|
||||
"text" : "Write a poem about sloths." }]
|
||||
}]
|
||||
do_gemma_4_inference(messages)
|
||||
|
||||
|
||||
# # Gemma 4 can also hear!
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
from IPython.display import Audio, display
|
||||
Audio("https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3")
|
||||
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
get_ipython().system('wget -qqq https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3 -O audio.mp3')
|
||||
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
audio_file = "audio.mp3"
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "audio", "audio" : audio_file },
|
||||
{ "type": "text", "text" : "What is this audio about?" }
|
||||
]
|
||||
}]
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# # Let's combine all 3 modalities together!
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "audio", "audio" : audio_file },
|
||||
{ "type": "image", "image" : sloth_link },
|
||||
{ "type": "text", "text" : "What is this audio and image about? "\
|
||||
"How are they related?" }
|
||||
]
|
||||
}]
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# # Let's finetune Gemma 4!
|
||||
#
|
||||
# You can finetune the vision and text parts for now through selection - the audio part can also be finetuned - we're working to make it selectable as well!
|
||||
|
||||
# We now add LoRA adapters so we only need to update a small amount of parameters!
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # Should leave on always!
|
||||
|
||||
r = 8, # Larger = higher accuracy, but might overfit
|
||||
lora_alpha = 8, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We now use the `Gemma-4` format for conversation style finetunes. We use [Maxime Labonne's FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) dataset in ShareGPT style. Gemma-4 renders multi turn conversations like below:
|
||||
#
|
||||
# ```
|
||||
# <bos><|turn>user
|
||||
# Hello<turn|>
|
||||
# <|turn>model
|
||||
# Hey there!<turn|>
|
||||
# ```
|
||||
# We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3, gemma-4` and more.
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4",
|
||||
)
|
||||
|
||||
|
||||
# We get the first 3000 rows of the dataset
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("mlabonne/FineTome-100k", split = "train[:3000]")
|
||||
|
||||
|
||||
# We now use `standardize_data_formats` to try converting datasets to the correct format for finetuning purposes!
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import standardize_data_formats
|
||||
dataset = standardize_data_formats(dataset)
|
||||
|
||||
|
||||
# Let's see how row 100 looks like!
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
dataset[100]
|
||||
|
||||
|
||||
# We now have to apply the chat template for `Gemma-4` onto the conversations, and save it to `text`. We remove the `<bos>` token using removeprefix(`'<bos>'`) since we're finetuning. The Processor will add this token before training and the model expects only one.
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversations"]
|
||||
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
|
||||
return { "text" : texts, }
|
||||
|
||||
dataset = dataset.map(formatting_prompts_func, batched = True)
|
||||
|
||||
|
||||
# Let's see how the chat template did! Notice there is no `<bos>` token as the processor tokenizer will be adding one.
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
dataset[100]["text"]
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`.
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset,
|
||||
eval_dataset = None, # Can set up evaluation!
|
||||
args = SFTConfig(
|
||||
dataset_text_field = "text",
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4, # Use GA to mimic batch size!
|
||||
warmup_steps = 5,
|
||||
# num_train_epochs = 1, # Set this for 1 full training run.
|
||||
max_steps = 60,
|
||||
learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
|
||||
logging_steps = 1,
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
report_to = "none", # Use TrackIO/WandB etc
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# We also use Unsloth's `train_on_completions` method to only train on the assistant outputs and ignore the loss on the user's inputs. This helps increase accuracy of finetunes!
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import train_on_responses_only
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|turn>user\n",
|
||||
response_part = "<|turn>model\n",
|
||||
)
|
||||
|
||||
|
||||
# Let's verify masking the instruction part is done! Let's print the 100th row again. Notice how the sample only has a single `<bos>` as expected!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
tokenizer.decode(trainer.train_dataset[100]["input_ids"])
|
||||
|
||||
|
||||
# Now let's print the masked out example - you should see only the answer is present:
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")
|
||||
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# # Let's train the model!
|
||||
#
|
||||
# To resume a training run, set `trainer.train(resume_from_checkpoint = True)`
|
||||
|
||||
# In[23]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[24]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model via Unsloth native inference! According to the `Gemma-4` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[25]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4",
|
||||
)
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type" : "text",
|
||||
"text" : "Continue the sequence: 1, 1, 2, 3, 5, 8,",
|
||||
}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
)
|
||||
tokenizer.batch_decode(outputs)
|
||||
|
||||
|
||||
# You can also use a `TextStreamer` for continuous inference - so you can see the generation token by token, instead of waiting the whole time!
|
||||
|
||||
# In[26]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "Why is the sky blue?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[27]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# tokenizer.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[28]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastModel
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "What is Gemma-4?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 128, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly for deployment! We save it in the folder `gemma-4-finetune`. Set `if False` to `if True` to let it run!
|
||||
|
||||
# In[29]:
|
||||
|
||||
|
||||
if False: # Change to True to save finetune!
|
||||
model.save_pretrained_merged("gemma-4-finetune", tokenizer)
|
||||
|
||||
|
||||
# If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[30]:
|
||||
|
||||
|
||||
if False: # Change to True to upload finetune
|
||||
model.push_to_hub_merged(
|
||||
"HF_ACCOUNT/gemma-4-finetune", tokenizer,
|
||||
token = "YOUR_HF_TOKEN"
|
||||
)
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
|
||||
|
||||
# In[31]:
|
||||
|
||||
|
||||
if False: # Change to True to save to GGUF
|
||||
model.save_pretrained_gguf(
|
||||
"gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
|
||||
)
|
||||
|
||||
|
||||
# Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[32]:
|
||||
|
||||
|
||||
if False: # Change to True to upload GGUF
|
||||
model.push_to_hub_gguf(
|
||||
"HF_ACCOUNT/gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma-4-finetune.gguf` file or `gemma-4-finetune-Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[ ]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[ ]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel # FastLanguageModel for LLMs
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
load_in_4bit = False, # Use 4bit to reduce memory use. False for 16bit LoRA.
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
|
||||
)
|
||||
|
||||
|
||||
# We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.
|
||||
#
|
||||
# **[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = True, # False if not finetuning vision layers
|
||||
finetune_language_layers = True, # False if not finetuning language layers
|
||||
finetune_attention_modules = True, # False if not finetuning attention layers
|
||||
finetune_mlp_modules = True, # False if not finetuning MLP layers
|
||||
|
||||
r = 32, # The larger, the higher the accuracy, but might overfit
|
||||
lora_alpha = 32, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
target_modules = "all-linear", # Optional now! Can specify a list if needed
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We'll use a sampled dataset of handwritten math formulas. The objective is to convert these images into a computer-readable format—specifically LaTeX—so they can be rendered. This is particularly useful for complex expressions.
|
||||
#
|
||||
# You can access the dataset [here](https://huggingface.co/datasets/unsloth/LaTeX_OCR). The full dataset is [here](https://huggingface.co/datasets/linxy/LaTeX_OCR).
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")
|
||||
|
||||
|
||||
# Let's take an overview of the dataset. We'll examine the second image and its corresponding caption.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
dataset
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
dataset[2]["image"]
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
dataset[2]["text"]
|
||||
|
||||
|
||||
# We can also render LaTeX directly in the browser!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from IPython.display import display, Math, Latex
|
||||
|
||||
latex = dataset[3]["text"]
|
||||
display(Math(latex))
|
||||
|
||||
|
||||
# To format the dataset, all vision fine-tuning tasks should follow this format:
|
||||
#
|
||||
# ```python
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# ]
|
||||
# ```
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
def convert_to_conversation(sample):
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": instruction},
|
||||
{"type": "image", "image": sample["image"]},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
|
||||
]
|
||||
return {"messages": conversation}
|
||||
pass
|
||||
|
||||
|
||||
# Let's convert the dataset into the "correct" format for finetuning:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
|
||||
|
||||
|
||||
# The first example is now structured like below:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
converted_dataset[0]
|
||||
|
||||
|
||||
# Lets take the Gemma 4 instruction chat template and use it in our base model
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import get_chat_template
|
||||
|
||||
processor = get_chat_template(
|
||||
processor,
|
||||
"gemma-4"
|
||||
)
|
||||
|
||||
|
||||
# Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
image = dataset[2]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# You can see it's absolutely terrible! It doesn't follow instructions at all
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support `DPOTrainer` and `GRPOTrainer` for reinforcement learning!
|
||||
#
|
||||
# We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth.trainer import UnslothVisionDataCollator
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = converted_dataset,
|
||||
processing_class = processor.tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, processor),
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4,
|
||||
max_grad_norm = 0.3,
|
||||
warmup_ratio = 0.03,
|
||||
max_steps = 60,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
learning_rate = 2e-4,
|
||||
logging_steps = 1,
|
||||
save_strategy = "steps",
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "cosine",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none", # For Weights and Biases or others
|
||||
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
max_length = 2048,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model! You can modify the instruction and input—just leave the output blank.
|
||||
#
|
||||
# We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
image = dataset[10]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
processor.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# processor.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
load_in_4bit = True, # Set to False for 16bit LoRA
|
||||
)
|
||||
|
||||
sample = dataset[1]
|
||||
image = sample["image"].convert("RGB")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sample["text"],
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor.tokenizer, skip_prompt = True)
|
||||
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Select ONLY 1 to save! (Both not needed!)
|
||||
|
||||
# Save locally to 16bit
|
||||
if False: model.save_pretrained_merged("unsloth_finetune", processor,)
|
||||
|
||||
# To export and save to your Hugging Face account
|
||||
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", processor, token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,911 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[ ]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[ ]:
|
||||
#
|
||||
#
|
||||
# #@title Colab Extra Install { display-mode: "form" }
|
||||
# get_ipython().run_line_magic('%capture', '')
|
||||
# import os
|
||||
# get_ipython().system('pip install --upgrade -qqq uv')
|
||||
# if "COLAB_" not in "".join(os.environ.keys()):
|
||||
# # If you're not in Colab, just use pip install!
|
||||
# get_ipython().system('pip install unsloth vllm')
|
||||
# else:
|
||||
# try: import numpy, PIL; _numpy = f'numpy=={numpy.__version__}'; _pil = f'pillow=={PIL.__version__}'
|
||||
# except: _numpy = "numpy"; _pil = "pillow"
|
||||
# try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
|
||||
# except: is_t4 = False
|
||||
# _vllm, _triton = ('vllm==0.9.2', 'triton==3.2.0') if is_t4 else ('vllm==0.15.1', 'triton')
|
||||
# get_ipython().system('uv pip install -qqq --upgrade {_vllm} {_numpy} {_pil} torchvision bitsandbytes xformers unsloth')
|
||||
# get_ipython().system('uv pip install -qqq {_triton}')
|
||||
# get_ipython().system('uv pip install transformers==4.56.2')
|
||||
# get_ipython().system('uv pip install --no-deps trl==0.22.2')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
|
||||
# # Goal: Make faster kernels with Reinforcement Learning
|
||||
#
|
||||
# Our goal is to make a faster matrix multiplication kernel by doing RL on Gemma 4 with Unsloth.
|
||||
#
|
||||
# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/18/Matrix_multiplication_qtl1.svg/500px-Matrix_multiplication_qtl1.svg.png" height=200 />
|
||||
#
|
||||
# You will learn how to:
|
||||
# 1. Counteract **reward hacking** like cheating, caching, laziness.
|
||||
# 2. Timing and correctness of kernels and time limits.
|
||||
# 3. Making good **reward functions**
|
||||
# 4. How to seriously do RL to make optimized kernels
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel
|
||||
import torch
|
||||
max_seq_length = 4096 # Can increase for longer reasoning traces
|
||||
lora_rank = 32 # Larger rank = smarter, but slower
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E2B-it",
|
||||
max_seq_length = max_seq_length,
|
||||
load_in_4bit = False, # False for LoRA 16bit
|
||||
fast_inference = False, # Enable vllm fast inference
|
||||
)
|
||||
|
||||
|
||||
# We now add some small amount of LoRA weights to Gemma 4 so we only need to train those, instead of training on the full model.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
lora_alpha = lora_rank*2, # *2 speeds up training
|
||||
use_gradient_checkpointing = "unsloth", # Reduces memory usage
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# # Optimized matrix multiplication
|
||||
#
|
||||
# Numpy has optimized matrix multiplication kernels for CPUs via BLAS optimized operations. For GPUs, one can use CUDA accelerated cuBLAS kernels which PyTorch calls under the hood.
|
||||
#
|
||||
# To generate some random matrices to do matrix multiplication, we can do the below:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
import numpy as np
|
||||
def generate_random_matrices(seed = 3407, n = 256):
|
||||
random_state = np.random.RandomState(seed)
|
||||
n, k, m = random_state.randint(1, n+1, size = 3)
|
||||
A = np.random.uniform(-10, 10, size = (n, k))
|
||||
B = np.random.uniform(-10, 10, size = (k, m))
|
||||
return A, A.tolist(), B, B.tolist()
|
||||
|
||||
|
||||
# We shall generate a small matrix, and see the matrix multiplied output
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
A, A_list, B, B_list = generate_random_matrices(seed = 42, n = 5)
|
||||
print(A)
|
||||
print(B)
|
||||
print(np.matmul(A, B))
|
||||
|
||||
|
||||
# We can call a LLM to generate a simple matrix multiply kernel in Python only, and we can calculate the differences between the actual result and the kernel's result
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def calculate_difference(pred, real):
|
||||
if pred is None: return 5, 5
|
||||
assert real is not None
|
||||
import numpy as np
|
||||
try:
|
||||
difference = pred - real
|
||||
except:
|
||||
return 5, 5
|
||||
amax_error = float(np.amax(difference))
|
||||
mse_error = float(np.mean(np.square(difference)))
|
||||
return amax_error, mse_error
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Kernel generated by GPT-5
|
||||
def matmul(A, B):
|
||||
z, s = zip, sum
|
||||
Bt = list(z(*B))
|
||||
return [[s(a*b for a, b in z(row, col)) for col in Bt] for row in A]
|
||||
|
||||
|
||||
# We see the error below is very small, so that's good!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
prediction = matmul(A_list, B_list)
|
||||
calculate_difference(prediction, np.matmul(A, B))
|
||||
|
||||
|
||||
# # Countering Reward Hacking
|
||||
#
|
||||
# The ultimate goal of RL is to maximize some reward (say speed, revenue, some metric).
|
||||
#
|
||||
# But RL can **cheat** When the RL algorithm learns a trick or exploits something to increase the reward, without actually doing the task at end, this is called "Reward Hacking".
|
||||
#
|
||||
# Some good examples are in https://en.wikipedia.org/wiki/Reward_hacking
|
||||
#
|
||||
# For matrix multiplication kernels, we might see the following issues:
|
||||
#
|
||||
# * Laziness: RL learns to use Numpy, Torch, other libraries, which calls optimized kernels.
|
||||
# * Caching: RL learns to cache the result of the output
|
||||
# * Cheating: RL learns to find the actual output by inspecting Python global variables
|
||||
# * RL learns to edit the timing function to make it output 0 time as passed.
|
||||
#
|
||||
# And possibly more. We shall try to address each!
|
||||
|
||||
# # Countering Reward Hacking 1: Stop laziness
|
||||
# We can stop the RL algorithm from calling optimized code by inspecting if the generated code imports other non standard Python libraries. We used GPT-5 to help generate this check `check_only_stdlib_imports`:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
#@title (Collapsible code)
|
||||
import ast
|
||||
import sys
|
||||
import sysconfig
|
||||
from pathlib import Path
|
||||
|
||||
def _stdlib_names():
|
||||
"""
|
||||
Build a set of canonical stdlib top-level module/package names.
|
||||
Uses sys.stdlib_module_names when available (3.10+), with a
|
||||
filesystem fallback for older versions/edge cases.
|
||||
"""
|
||||
names = {m.lower() for m in getattr(sys, "stdlib_module_names", set())}
|
||||
names |= {m.lower() for m in sys.builtin_module_names}
|
||||
names.add("__future__") # special-case
|
||||
|
||||
# Fallback/augmentation: scan the stdlib directory
|
||||
try:
|
||||
stdlib_dir = Path(sysconfig.get_path("stdlib"))
|
||||
if stdlib_dir.exists():
|
||||
for p in stdlib_dir.iterdir():
|
||||
if p.name == "site-packages":
|
||||
continue
|
||||
if p.suffix == ".py":
|
||||
names.add(p.stem.lower())
|
||||
elif p.is_dir() and (p / "__init__.py").exists():
|
||||
names.add(p.name.lower())
|
||||
except Exception:
|
||||
# conservative fallback; the names set above will still work well
|
||||
pass
|
||||
|
||||
return names
|
||||
|
||||
_STDLIB_SET = _stdlib_names()
|
||||
|
||||
def check_only_stdlib_imports(code: str):
|
||||
"""
|
||||
Return (ok: bool, details: dict)
|
||||
|
||||
ok == True -> all absolute imports are from the stdlib.
|
||||
ok == False -> details['non_stdlib'] lists offending top-level modules.
|
||||
|
||||
details includes:
|
||||
- stdlib: sorted list of stdlib imports found
|
||||
- non_stdlib: sorted list of non-stdlib imports found
|
||||
- relative_imports: count of relative imports (always allowed here)
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
return False, {
|
||||
"error": f"SyntaxError: {e}",
|
||||
"stdlib": [],
|
||||
"non_stdlib": [],
|
||||
"relative_imports": 0,
|
||||
}
|
||||
|
||||
abs_imports = set()
|
||||
relative_count = 0
|
||||
|
||||
class Visitor(ast.NodeVisitor):
|
||||
def visit_Import(self, node: ast.Import):
|
||||
for alias in node.names:
|
||||
abs_imports.add(alias.name.split(".")[0])
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom):
|
||||
nonlocal relative_count
|
||||
if (node.level or 0) > 0:
|
||||
# relative import
|
||||
relative_count += 1
|
||||
else:
|
||||
if node.module:
|
||||
abs_imports.add(node.module.split(".")[0])
|
||||
|
||||
Visitor().visit(tree)
|
||||
|
||||
stdlib_found = sorted(m for m in abs_imports if m.lower() in _STDLIB_SET)
|
||||
non_stdlib = sorted(m for m in abs_imports if m.lower() not in _STDLIB_SET)
|
||||
|
||||
return len(non_stdlib) == 0, {
|
||||
"stdlib": stdlib_found,
|
||||
"non_stdlib": non_stdlib,
|
||||
"relative_imports": relative_count,
|
||||
}
|
||||
|
||||
|
||||
# For example, let's call `check_only_stdlib_imports` on a random piece of matrix multiplication code generated by GPT-5:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
sample = """
|
||||
def matmul(A, B):
|
||||
import numpy as np
|
||||
from torch import matmul
|
||||
z, s = zip, sum
|
||||
Bt = list(z(*B))
|
||||
return [[s(a*b for a, b in z(row, col)) for col in Bt] for row in A]
|
||||
"""
|
||||
ok, info = check_only_stdlib_imports(sample)
|
||||
print("Only stdlib imports?", ok)
|
||||
print(info)
|
||||
|
||||
|
||||
# # Countering Reward Hacking 2: Stop cheating
|
||||
# We can stop the RL algorithm from using global or cached variables by restricting it's `locals` and `globals`.
|
||||
#
|
||||
# We are also going to use `exec` to create the function, so we have to save the output to an empty dict.
|
||||
#
|
||||
# We also disallow global variable access.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
output_function = {}
|
||||
exec(sample, {}, output_function)
|
||||
output_function["matmul"]
|
||||
|
||||
|
||||
# We also disallow global variable access via `types.FunctionType(f.__code__, {})`
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
import types
|
||||
output_function["matmul"] = types.FunctionType(output_function["matmul"].__code__, {})
|
||||
|
||||
def import_numpy():
|
||||
np.matmul
|
||||
print("Success")
|
||||
|
||||
import_numpy()
|
||||
import_numpy = types.FunctionType(import_numpy.__code__, {})
|
||||
try:
|
||||
import_numpy()
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def create_locked_down_function(function):
|
||||
output_function = {}
|
||||
exec(function, {}, output_function)
|
||||
new_matmul = output_function["matmul"]
|
||||
new_matmul = types.FunctionType(new_matmul.__code__, {})
|
||||
return new_matmul
|
||||
|
||||
|
||||
# # Countering Reward Hacking 3: Stop caching
|
||||
# We can stop the RL algorithm from using cached data by wiping the cache with a large fake matrix. We also have to benchmark carefully with multiple loops and turns.
|
||||
#
|
||||
# We also add a **timer** to not make the algorithm go in an endless loop.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
import os, gc, time, statistics
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
class TimeoutError(Exception): pass
|
||||
|
||||
@contextmanager
|
||||
def time_limit(seconds):
|
||||
def _handler(signum, frame):
|
||||
raise TimeoutError(f"Timed out after {seconds}s")
|
||||
old = signal.signal(signal.SIGALRM, _handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.setitimer(signal.ITIMER_REAL, 0.0)
|
||||
signal.signal(signal.SIGALRM, old)
|
||||
|
||||
class Benchmarker:
|
||||
def __init__(self, trials = 3, loops = 1, timeout = 30):
|
||||
self.buffer = np.zeros(2 * 1024 * 1024 * 1024, dtype = np.uint8)
|
||||
self.trials = trials
|
||||
self.loops = loops
|
||||
assert timeout > 0 # Cannot be 0 since it won't work!
|
||||
self.timeout = timeout
|
||||
def thrash(self):
|
||||
# Edit the buffer to wipe cache lines
|
||||
self.buffer ^= 1
|
||||
return int(self.buffer[::4096].sum())
|
||||
|
||||
def benchmark(self, function, arguments):
|
||||
assert len(arguments) == self.loops
|
||||
samples = []
|
||||
exceptions = []
|
||||
timed_out = 0
|
||||
for _ in range(self.trials):
|
||||
gc.collect(); gc.disable(); self.thrash()
|
||||
t_start = time.perf_counter_ns()
|
||||
for i in range(self.loops):
|
||||
try:
|
||||
with time_limit(self.timeout):
|
||||
function(*arguments[i])
|
||||
except TimeoutError as e:
|
||||
timed_out += 1
|
||||
except Exception as e:
|
||||
exceptions.append(str(e))
|
||||
t_end = time.perf_counter_ns()
|
||||
gc.enable()
|
||||
samples.append((t_end - t_start) // max(1, self.loops))
|
||||
return {
|
||||
"median_ns": int(statistics.median(samples)),
|
||||
"mean_ns": int(statistics.fmean(samples)),
|
||||
"stdev_ns": int(statistics.pstdev(samples) if len(samples) > 1 else 0),
|
||||
"exceptions" : exceptions,
|
||||
"timeouts" : timed_out,
|
||||
}
|
||||
|
||||
|
||||
# For example we use our matmul kernel we had, and benchmark it with a 10 second delay:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
A, A_list, B, B_list = generate_random_matrices(seed = 0, n = 256)
|
||||
Benchmarker(trials = 1, timeout = 10).benchmark(output_function["matmul"], [(A_list, B_list)])
|
||||
|
||||
|
||||
# # Data & RL task setup
|
||||
#
|
||||
# We now have to create a prompt to the model for which it will do some task. For our matrix multiply example, we use the below:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
prompt = """
|
||||
Create a new fast matrix multiplication function using only native Python code.
|
||||
You are given a list of list of numbers.
|
||||
Output your new function in backticks using the format below:
|
||||
```python
|
||||
def matmul(A, B):
|
||||
return ...
|
||||
```
|
||||
""".strip()
|
||||
print(prompt)
|
||||
|
||||
|
||||
# First, let's prompt Gemma 4 without RL and see how it goes:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
tokenize = False,
|
||||
add_generation_prompt = True,
|
||||
)
|
||||
|
||||
from transformers import TextStreamer
|
||||
print("=" * 50)
|
||||
print("BASE MODEL OUTPUT (before RL training):")
|
||||
print("=" * 50)
|
||||
|
||||
inputs = tokenizer(
|
||||
text = text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# # Reward functions
|
||||
#
|
||||
# We now design the `extract_function` function which simply extracts the function wrapped in 3 backticks.
|
||||
#
|
||||
# And 4 reward functions:
|
||||
#
|
||||
# 1. `function_works` which rewards the model if the strategy is a valid Python function.
|
||||
# 2. `no_cheating` which checks if the function imported other modules, and if it did, we penalize it.
|
||||
# 3. `correctness_check` which checks if the kernel was correct or wrong - it shouldn't generate gibberish!
|
||||
# 4. `speed_check` checks the performance relative to Numpy matmul directly.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def extract_function(text):
|
||||
if text.count("```") >= 2:
|
||||
first = text.find("```") + 3
|
||||
second = text.find("```", first)
|
||||
fx = text[first : second].strip()
|
||||
fx = fx.removeprefix("python\n")
|
||||
fx = fx[fx.find("def"):]
|
||||
if fx.startswith("def matmul(A, B):"): return fx
|
||||
return None
|
||||
print(extract_function(prompt))
|
||||
|
||||
|
||||
# Below is our `function_works` reward function which uses Python's `exec` but guarded by not allowing leakage of local and global variables. We can also use `check_only_stdlib_imports` first to check if there are errors before even executing the function:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
ok, info = check_only_stdlib_imports("def a")
|
||||
ok, info
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def function_works(completions, **kwargs):
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
print(function)
|
||||
if function is not None:
|
||||
ok, info = check_only_stdlib_imports(function)
|
||||
if function is None or "error" in info:
|
||||
score = -2.0
|
||||
else:
|
||||
try:
|
||||
new_matmul = create_locked_down_function(function)
|
||||
score = 1.0
|
||||
except:
|
||||
score = -0.5
|
||||
scores.append(score)
|
||||
return scores
|
||||
|
||||
|
||||
# `no_cheating` checks if the function cheated since it might have imported Numpy or Torch optimized code.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def no_cheating(completions, **kwargs):
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
if function is not None:
|
||||
ok, info = check_only_stdlib_imports(function)
|
||||
else:
|
||||
ok = False
|
||||
scores.append(1.0 if ok else -20.0) # Penalize heavily!
|
||||
return scores
|
||||
|
||||
|
||||
# Next `correctness_check` checks if the kernel was correct. We want to penalize if the absolute error is larger than 1, and if the mean squared error is somewhat bigger then machine epsilon.
|
||||
#
|
||||
# We have to execute the code now!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
np.finfo(np.float64).eps
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def correctness_check(completions, **kwargs):
|
||||
scores = []
|
||||
# Generate some random matrices of size less than 128
|
||||
A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 128)
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
if function is not None:
|
||||
ok, info = check_only_stdlib_imports(function)
|
||||
if function is None or "error" in info:
|
||||
scores.append(0)
|
||||
continue
|
||||
try:
|
||||
new_matmul = create_locked_down_function(function)
|
||||
except:
|
||||
scores.append(0)
|
||||
continue
|
||||
try:
|
||||
pred = new_matmul(A_list.copy(), B_list.copy())
|
||||
except:
|
||||
# Failed!
|
||||
scores.append(-2.0)
|
||||
continue
|
||||
true = np.matmul(A, B)
|
||||
amax_error, mse_error = calculate_difference(pred, true)
|
||||
|
||||
# Check correctness and score!
|
||||
machine_epsilon = 100*np.finfo(np.float64).eps
|
||||
if amax_error >= 3: score = -3.0
|
||||
elif amax_error >= 2: score = -2.5
|
||||
elif amax_error >= 1: score = -2.0
|
||||
elif amax_error >= 0.5: score = -1.0
|
||||
elif amax_error >= 100*machine_epsilon: score = 0.0
|
||||
elif amax_error >= machine_epsilon: score = 1.0
|
||||
else: score = 3.0
|
||||
|
||||
if mse_error >= 3: score += -3.0
|
||||
elif mse_error >= 2: score += -2.5
|
||||
elif mse_error >= 1: score += -2.0
|
||||
elif mse_error >= 0.5: score += -1.0
|
||||
elif mse_error >= 100*machine_epsilon: score += 0.0
|
||||
elif mse_error >= machine_epsilon: score += 1.0
|
||||
else: score += 3.0
|
||||
scores.append(score)
|
||||
return scores
|
||||
|
||||
|
||||
# Finally our benchmarking function for `speed_check`! We shall limit the timer to 10 seconds and do 3 trials.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
A, A_list, B, B_list = generate_random_matrices(seed = 0, n = 256)
|
||||
benchmarker = Benchmarker(trials = 3, timeout = 10)
|
||||
numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])
|
||||
numpy_results
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
new_matmul = create_locked_down_function(extract_function(prompt))
|
||||
new_results = benchmarker.benchmark(new_matmul, [(A_list, B_list)])
|
||||
new_results
|
||||
|
||||
|
||||
# We can take the difference and do a negative sign for slower ones. If the ratio is less than 1 (ie faster, we shall invert it!)
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
|
||||
positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
|
||||
reward = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
|
||||
reward
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
new_results["median_ns"] = 3
|
||||
numpy_results["median_ns"] = 1000
|
||||
negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
|
||||
positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
|
||||
reward = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
|
||||
reward
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
import gc
|
||||
def speed_check(completions, **kwargs):
|
||||
scores = []
|
||||
# Generate some random matrices of size less than 256
|
||||
A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 256)
|
||||
numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
if function is not None:
|
||||
ok, info = check_only_stdlib_imports(function)
|
||||
if function is None or "error" in info:
|
||||
scores.append(0)
|
||||
continue
|
||||
try:
|
||||
new_matmul = create_locked_down_function(function)
|
||||
except:
|
||||
scores.append(0)
|
||||
continue
|
||||
new_results = benchmarker.benchmark(new_matmul, [(A_list.copy(), B_list.copy())])
|
||||
|
||||
# Get score and clip to -10, 10
|
||||
negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
|
||||
positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
|
||||
score = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
|
||||
if score >= 10: score = 10
|
||||
if score <= -10: score = -10
|
||||
scores.append(score)
|
||||
# Free memory to counteract OOMs
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return scores
|
||||
|
||||
|
||||
# We create the dataset which includes a replica of our prompt.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from datasets import Dataset
|
||||
dataset = Dataset.from_list([{"prompt" : [{"role": "user", "content": prompt.strip()}], "answer" : 0}]*1000)
|
||||
maximum_length = len(tokenizer.apply_chat_template([{"role":"user", "content":prompt.strip()}], add_generation_prompt = True, tokenize = True))
|
||||
print(maximum_length)
|
||||
dataset[0]
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
#
|
||||
# Now set up GRPO Trainer and all configurations! We also support GSDP, GAPO, Dr GRPO and more! Go to our docs https://unsloth.ai/docs/ for more info!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Leave room for the prompt (plus 1 token safety margin)
|
||||
max_completion_length = max_seq_length - (maximum_length + 1)
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
training_args = GRPOConfig(
|
||||
temperature = 1.0,
|
||||
top_p = 0.95,
|
||||
top_k = 64,
|
||||
learning_rate = 5e-5,
|
||||
weight_decay = 0.001,
|
||||
warmup_ratio = 0.1,
|
||||
lr_scheduler_type = "linear",
|
||||
optim = "adamw_8bit",
|
||||
logging_steps = 1,
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 2, # Increase to 4 for smoother training
|
||||
num_generations = 2, # Decrease if out of memory
|
||||
max_completion_length = max_completion_length,
|
||||
# num_train_epochs = 1, # Set to 1 for a full training run
|
||||
max_steps = 100,
|
||||
save_steps = 100,
|
||||
report_to = "none", # Can use Weights & Biases, TrackIO
|
||||
output_dir = "outputs",
|
||||
epsilon = 0.2,
|
||||
epsilon_high = 0.28, # one sided
|
||||
delta = 1.5, # two sided
|
||||
loss_type = 'bnpo',
|
||||
mask_truncated_completions = True
|
||||
# For optional training + evaluation
|
||||
# fp16_full_eval = True,
|
||||
# per_device_eval_batch_size = 4,
|
||||
# eval_accumulation_steps = 1,
|
||||
# eval_strategy = "steps",
|
||||
# eval_steps = 1,
|
||||
)
|
||||
|
||||
|
||||
# And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!
|
||||
#
|
||||
# You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!
|
||||
#
|
||||
# | Step | Training Loss | reward | reward_std | completion_length | kl |
|
||||
# |------|---------------|-----------|------------|-------------------|----------|
|
||||
# | 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |
|
||||
# | 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |
|
||||
# | 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# For optional training + evaluation
|
||||
# new_dataset = dataset.train_test_split(test_size = 0.01)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model = model,
|
||||
processing_class = tokenizer,
|
||||
reward_funcs = [
|
||||
function_works,
|
||||
no_cheating,
|
||||
correctness_check,
|
||||
speed_check,
|
||||
],
|
||||
args = training_args,
|
||||
train_dataset = dataset,
|
||||
|
||||
# For optional training + evaluation
|
||||
# train_dataset = new_dataset["train"],
|
||||
# eval_dataset = new_dataset["test"],
|
||||
)
|
||||
|
||||
|
||||
# And let's train the model!
|
||||
#
|
||||
# **NOTE** A T4 free GPU might take 5 minutes for one generation sadly since it's an old GPU - A100 or H100 will be much faster!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
# And now with the LoRA we just trained with GRPO - we first save the LoRA first!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
|
||||
|
||||
# Verify LoRA is actually trained!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
tensors = {}
|
||||
with safe_open("grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
|
||||
# Verify both A and B are non zero
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
n_zeros = (tensor == 0).sum() / tensor.numel()
|
||||
assert(n_zeros.item() != tensor.numel())
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# # Inference
|
||||
# Now let's try the model we just trained!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
tokenize = False,
|
||||
add_generation_prompt = True,
|
||||
)
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
_ = model.generate(
|
||||
**tokenizer(images = None, text = text, return_tensors = "pt").to("cuda"),
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
max_new_tokens = 1024,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = False),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Merge to 16bit
|
||||
if False: model.save_pretrained_merged("gemma_4_finetune_16bit", tokenizer, save_method = "merged_16bit",)
|
||||
if False: model.push_to_hub_merged("HF_USERNAME/gemma_4_finetune_16bit", tokenizer, save_method = "merged_16bit", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Merge to 4bit
|
||||
if False: model.save_pretrained_merged("gemma_4_finetune_4bit", tokenizer, save_method = "merged_4bit",)
|
||||
if False: model.push_to_hub_merged("HF_USERNAME/gemma_4_finetune_4bit", tokenizer, save_method = "merged_4bit", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Just LoRA adapters
|
||||
if False:
|
||||
model.save_pretrained("gemma_4_lora")
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
if False:
|
||||
model.push_to_hub("HF_USERNAME/gemma_4_lora", token = "YOUR_HF_TOKEN")
|
||||
tokenizer.push_to_hub("HF_USERNAME/gemma_4_lora", token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.
|
||||
#
|
||||
# Some supported quant methods (full list on our [docs page](https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf)):
|
||||
# * `q8_0` - Fast conversion. High resource use, but generally acceptable.
|
||||
# * `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
|
||||
# * `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.
|
||||
#
|
||||
# [**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Save to 8bit Q8_0
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer,)
|
||||
# Remember to go to https://huggingface.co/settings/tokens for a token!
|
||||
# And change hf to your username!
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to 16bit GGUF
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer, quantization_method = "f16")
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, quantization_method = "f16", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to q4_k_m GGUF
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer, quantization_method = "q4_k_m")
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, quantization_method = "q4_k_m", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to multiple GGUF options - much faster if you want multiple!
|
||||
if False:
|
||||
model.push_to_hub_gguf(
|
||||
"HF_USERNAME/gemma_4_finetune", # Change hf to your username!
|
||||
tokenizer,
|
||||
quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma_4_finetune.Q8_0.gguf` file or `gemma_4_finetune.Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
+913
@@ -0,0 +1,913 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# # Goal: Make Gemma 4 play games with Reinforcement Learning
|
||||
#
|
||||
# Our goal is to make Gemma 4 play the 2048 game with reinforcement learning, or a variant of it called [GRPO](https://arxiv.org/abs/2501.12948).
|
||||
#
|
||||
# We want the model to devise a strategy to play 2048, and we will run this strategy until we win or lose. We then reward the model if it created a good strategy (winning the game), and we'll penalize it (negative reward) if the strategy was a bad one.
|
||||
#
|
||||
# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/2048_win.png/500px-2048_win.png" height=300 />
|
||||
|
||||
# # Installation
|
||||
# We'll be using [Unsloth](https://github.com/unslothai/unsloth) to do RL on Gemma 4. Unsloth saves 70% VRAM usage and makes reinforcement learning 2 to 6x faster!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
get_ipython().run_cell_magic('capture', '', 'import os, importlib.util\n!pip install --upgrade -qqq uv\nif importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):\n try: import numpy, PIL; _numpy = f"numpy=={numpy.__version__}"; _pil = f"pillow=={PIL.__version__}"\n except: _numpy = "numpy"; _pil = "pillow"\n # Gemma 4 requires transformers >= 5.5.0 — do NOT pin to 4.x here\n !uv pip install -qqq \\\n "torch>=2.8.0" "triton>=3.4.0" {_numpy} {_pil} torchvision bitsandbytes \\\n "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \\\n "unsloth[base] @ git+https://github.com/unslothai/unsloth" \\\n git+https://github.com/triton-lang/triton.git@0add68262ab0a2e33b84524346cb27cbb2787356#subdirectory=python/triton_kernels\nelif importlib.util.find_spec("unsloth") is None:\n !uv pip install -qqq unsloth\n# Gemma 4 requires transformers >= 5.5.0\n!uv pip install --upgrade --no-deps "transformers>=5.5.0" tokenizers "trl>=0.28.0" unsloth unsloth_zoo\n')
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
|
||||
|
||||
# ### Unsloth
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel
|
||||
import torch
|
||||
max_seq_length = 4096 # Can increase for longer reasoning traces
|
||||
lora_rank = 32 # Larger rank = smarter, but slower
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E2B-it",
|
||||
max_seq_length = max_seq_length,
|
||||
load_in_4bit = False, # False for LoRA 16bit
|
||||
fast_inference = False, # Enable vllm fast inference
|
||||
)
|
||||
|
||||
|
||||
# To do efficient RL, we will use [LoRA](https://arxiv.org/abs/2106.09685), which allows us to only add 1 to 5% of extra weights to the model for finetuning purposes. This allows us to save memory usage by over 60%, and yet it retains good accuracy.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
lora_alpha = lora_rank*2, # *2 speeds up training
|
||||
use_gradient_checkpointing = "unsloth", # Reduces memory usage
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# # 2048 game
|
||||
#
|
||||
# We used GPT-5 to create a variant of the 2048 game. It should output the current game board state, and allow us to advance the game board state with 1 action (up, down, left, right).
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
#@title (Collapsible) 2048 Game Implementation
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Tuple, Optional
|
||||
import random
|
||||
import copy
|
||||
|
||||
def _compress_and_merge_row_left(row: List[int]) -> Tuple[List[int], int, bool]:
|
||||
n = len(row)
|
||||
tiles = [x for x in row if x != 0]
|
||||
gained = 0
|
||||
i = 0
|
||||
merged = []
|
||||
while i < len(tiles):
|
||||
if i + 1 < len(tiles) and tiles[i] == tiles[i + 1]:
|
||||
v = tiles[i] * 2
|
||||
gained += v
|
||||
merged.append(v)
|
||||
i += 2
|
||||
else:
|
||||
merged.append(tiles[i])
|
||||
i += 1
|
||||
merged += [0] * (n - len(merged))
|
||||
changed = merged != row
|
||||
return merged, gained, changed
|
||||
|
||||
def _move_left(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:
|
||||
changed_any = False
|
||||
total_gain = 0
|
||||
new_board = []
|
||||
for row in board:
|
||||
new_row, gained, changed = _compress_and_merge_row_left(row)
|
||||
new_board.append(new_row)
|
||||
total_gain += gained
|
||||
changed_any = changed_any or changed
|
||||
return new_board, total_gain, changed_any
|
||||
|
||||
def _move_right(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:
|
||||
changed_any = False
|
||||
total_gain = 0
|
||||
new_board = []
|
||||
for row in board:
|
||||
rev = list(reversed(row))
|
||||
new_rev, gained, changed = _compress_and_merge_row_left(rev)
|
||||
new_row = list(reversed(new_rev))
|
||||
new_board.append(new_row)
|
||||
total_gain += gained
|
||||
changed_any = changed_any or changed
|
||||
return new_board, total_gain, changed_any
|
||||
|
||||
def _transpose(board: List[List[int]]) -> List[List[int]]:
|
||||
return [list(row) for row in zip(*board)]
|
||||
|
||||
def _move_up(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:
|
||||
t = _transpose(board)
|
||||
moved, gain, changed = _move_left(t)
|
||||
return _transpose(moved), gain, changed
|
||||
|
||||
def _move_down(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:
|
||||
t = _transpose(board)
|
||||
moved, gain, changed = _move_right(t)
|
||||
return _transpose(moved), gain, changed
|
||||
|
||||
def _empty_cells(board: List[List[int]]) -> List[Tuple[int, int]]:
|
||||
size = len(board)
|
||||
return [(r, c) for r in range(size) for c in range(size) if board[r][c] == 0]
|
||||
|
||||
def _can_move(board: List[List[int]]) -> bool:
|
||||
if _empty_cells(board):
|
||||
return True
|
||||
size = len(board)
|
||||
for r in range(size):
|
||||
for c in range(size - 1):
|
||||
if board[r][c] == board[r][c + 1]:
|
||||
return True
|
||||
for r in range(size - 1):
|
||||
for c in range(size):
|
||||
if board[r][c] == board[r + 1][c]:
|
||||
return True
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class GameBoard:
|
||||
size: int
|
||||
seed: Optional[int] = None
|
||||
target: int = 2048
|
||||
probability_fours: float = 0.10 # originally spawns (4) 10% of the time!
|
||||
_rng: random.Random = field(init = False, repr = False)
|
||||
_board: List[List[int]] = field(init = False, repr = False)
|
||||
_score: int = field(default = 0, init = False, repr = False)
|
||||
_state: str = field(default = "ongoing", init = False, repr = False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.size < 2:
|
||||
raise ValueError("Board size must be at least 2.")
|
||||
self._rng = random.Random(self.seed)
|
||||
self._board = [[0 for _ in range(self.size)] for _ in range(self.size)]
|
||||
self._add_random_tile()
|
||||
self._add_random_tile()
|
||||
self._update_state_after_change()
|
||||
|
||||
class _BoardView:
|
||||
def __init__(self, game: "GameBoard"):
|
||||
self._game = game
|
||||
def __iter__(self):
|
||||
return iter(self._game._board)
|
||||
def __len__(self):
|
||||
return len(self._game._board)
|
||||
def __getitem__(self, idx):
|
||||
return self._game._board[idx]
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._game._board)
|
||||
__str__ = __repr__
|
||||
def do_action(self, key: str) -> None:
|
||||
self._game.do_action(key)
|
||||
def state(self) -> str:
|
||||
return self._game.state()
|
||||
def pretty(self, colors: bool = True, border: bool = True, dot_for_zero: bool = True) -> str:
|
||||
return self._game._render_pretty(colors = colors, border = border, dot_for_zero = dot_for_zero)
|
||||
|
||||
def board(self) -> "_BoardView":
|
||||
return GameBoard._BoardView(self)
|
||||
def state(self) -> str:
|
||||
return self._state
|
||||
def score(self) -> int:
|
||||
return self._score
|
||||
def do_action(self, key: str) -> None:
|
||||
if self._state != "ongoing":
|
||||
return
|
||||
if not isinstance(key, str) or len(key) == 0:
|
||||
self._state = "failed"
|
||||
return
|
||||
k = key.strip().lower()
|
||||
if k == "q":
|
||||
self._state = "failed"
|
||||
return
|
||||
move_map = {"a": _move_left, "d": _move_right, "w": _move_up, "s": _move_down}
|
||||
if k not in move_map:
|
||||
self._state = "failed"
|
||||
return
|
||||
mover = move_map[k]
|
||||
new_board, gain, changed = mover(self._board)
|
||||
if changed:
|
||||
self._board = new_board
|
||||
self._score += gain
|
||||
self._add_random_tile()
|
||||
self._update_state_after_change()
|
||||
def _add_random_tile(self) -> bool:
|
||||
empties = _empty_cells(self._board)
|
||||
if not empties:
|
||||
return False
|
||||
r, c = self._rng.choice(empties)
|
||||
self._board[r][c] = 4 if self._rng.random() < self.probability_fours else 2
|
||||
return True
|
||||
def _update_state_after_change(self) -> None:
|
||||
if any(self.target in row for row in self._board):
|
||||
self._state = "success"
|
||||
return
|
||||
if not _can_move(self._board):
|
||||
self._state = "failed"
|
||||
return
|
||||
self._state = "ongoing"
|
||||
def _render_pretty(self, colors: bool = True, border: bool = True, dot_for_zero: bool = True) -> str:
|
||||
"""
|
||||
Pretty-print the board with colors that scale from 0 up to self.target.
|
||||
Uses ANSI 256-color codes (works in most terminals). Set colors = False to disable.
|
||||
"""
|
||||
import math
|
||||
|
||||
b = self._board
|
||||
mx = max((max(row) for row in b), default = 0)
|
||||
cell_w = max(3, len(str(mx)))
|
||||
|
||||
RESET = "\x1b[0m"
|
||||
|
||||
# A smooth-ish gradient from cool → warm
|
||||
# (blue/cyan/green → yellow/orange/red). Tweak or expand as you like.
|
||||
GRAD = [33, 39, 45, 51, 50, 49, 48, 47, 46, 82, 118, 154, 190, 226, 220, 214, 208, 202, 196]
|
||||
ZERO_FG = 239 # dim gray
|
||||
|
||||
def color_code(v: int) -> str:
|
||||
if not colors:
|
||||
return ""
|
||||
if v == 0:
|
||||
return f"\x1b[38;5;{ZERO_FG}m"
|
||||
# Normalize by exponent relative to target: r in [0,1]
|
||||
t = max(2, self.target) # safety; avoid log2(1)
|
||||
# Guard: if v is not a power of two or is <1, handle gracefully
|
||||
try:
|
||||
r = max(0.0, min(1.0, math.log2(v) / math.log2(t)))
|
||||
except ValueError:
|
||||
r = 0.0
|
||||
idx = int(round(r * (len(GRAD) - 1)))
|
||||
return f"\x1b[38;5;{GRAD[idx]}m"
|
||||
|
||||
def fmt(v: int) -> str:
|
||||
s = "." if (v == 0 and dot_for_zero) else str(v)
|
||||
s = s.rjust(cell_w)
|
||||
return color_code(v) + s + (RESET if colors else "")
|
||||
|
||||
def hline(left: str, mid: str, right: str) -> str:
|
||||
return left + mid.join("─" * cell_w for _ in range(self.size)) + right
|
||||
|
||||
rows = []
|
||||
if border:
|
||||
rows.append(hline("┌", "┬", "┐"))
|
||||
for r in range(self.size):
|
||||
content = "│".join(fmt(v) for v in b[r])
|
||||
rows.append(("│" + content + "│") if border else content)
|
||||
if border:
|
||||
rows.append(hline("└" if r == self.size - 1 else "├",
|
||||
"┴" if r == self.size - 1 else "┼",
|
||||
"┘" if r == self.size - 1 else "┤"))
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
# For example let's create a board of size 5 X 5 and set the target to 8 instead of 2048.
|
||||
#
|
||||
# **[NOTE]** 2048 originally spawns a (4) 10% of the time! We can disable this for harder games. See [Wikipedia page](https://en.wikipedia.org/wiki/2048_(video_game)) for more details.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game = GameBoard(size = 5, seed = 42, target = 8, probability_fours = 0.10)
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game
|
||||
|
||||
|
||||
# We'll use WASD for the action space:
|
||||
#
|
||||
# ```
|
||||
# W
|
||||
# A S D
|
||||
# ```
|
||||
# Also `game.state()` will say `success` if we succeeded in getting the target!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game.do_action("A")
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game.do_action("W")
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game.do_action("D")
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game.do_action("W")
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game.do_action("D")
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# If we do some other action that's not part of the action space, we will get an error, and the game will not accept anymore actions.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game = GameBoard(size = 3, seed = 42, target = 8, probability_fours = 0.10)
|
||||
game.do_action("AA") # Not in WASD
|
||||
game.do_action("W") # Doesn't do anything
|
||||
game.do_action("A") # Doesn't do anything
|
||||
print(game.board().pretty(), game.state())
|
||||
|
||||
|
||||
# # RL Environment Setup
|
||||
#
|
||||
# We'll set up a function to accept some strategy that'll emit an action within `WASD` and check the game state.
|
||||
#
|
||||
# We'll also add a timer to only execute the strategy for 2 seconds maximum, otherwise it might never terminate!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from typing import Callable
|
||||
from unsloth import execute_with_time_limit
|
||||
|
||||
def _execute_strategy(strategy : Callable, game : GameBoard):
|
||||
assert callable(strategy)
|
||||
|
||||
steps = 0
|
||||
while game.state() == "ongoing":
|
||||
action = strategy(list(game.board()))
|
||||
steps += 1
|
||||
if type(action) is not str:
|
||||
return steps, "failed"
|
||||
game.do_action(action)
|
||||
return steps, game.state()
|
||||
|
||||
@execute_with_time_limit(2)
|
||||
def execute_strategy(strategy : Callable, game : GameBoard):
|
||||
return _execute_strategy(strategy, game)
|
||||
|
||||
|
||||
# Let's make a generic strategy to just hit `W`. We should expect this generic strategy to fail:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def always_move_left(board):
|
||||
return "W"
|
||||
|
||||
game = GameBoard(size = 8, seed = 42, target = 2048, probability_fours = 0.10)
|
||||
try:
|
||||
execute_strategy(always_move_left, game)
|
||||
except TimeoutError as e:
|
||||
print(f"Timed out with error = {str(e)}")
|
||||
|
||||
|
||||
# To allow longer strategies for Gemma 4 Reinforcement Learning, we shall allow a 5 second timer.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
@execute_with_time_limit(5)
|
||||
def execute_strategy(strategy : Callable, game : GameBoard):
|
||||
return _execute_strategy(strategy, game)
|
||||
|
||||
|
||||
# # Code Execution
|
||||
#
|
||||
# To execute and create a new Python function, we first have to check if the function does not call other global variables or cheat. This is called `countering reward hacking` since we don't want the function to cheat.
|
||||
#
|
||||
# For example the below piece of code is fine, since it only imports Python level functions. We use `check_python_modules`:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import check_python_modules
|
||||
|
||||
sample = """
|
||||
def strategy(board):
|
||||
import math
|
||||
from typing import Callable
|
||||
return "W"
|
||||
"""
|
||||
ok, info = check_python_modules(sample)
|
||||
print("Only Python imports?", ok)
|
||||
print(info)
|
||||
|
||||
|
||||
# For the below piece of code, since we import `numpy`, we should not allow the execution:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
sample = """
|
||||
def strategy(board):
|
||||
from numpy import matmul
|
||||
return "W"
|
||||
"""
|
||||
ok, info = check_python_modules(sample)
|
||||
print("Only Python imports?", ok)
|
||||
print(info)
|
||||
|
||||
|
||||
# We also disallow global variable access. We'll use Unsloth's `create_locked_down_function` function
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import create_locked_down_function
|
||||
function = """
|
||||
def import_numpy():
|
||||
np.matmul
|
||||
print("Success")
|
||||
"""
|
||||
f = create_locked_down_function(function)
|
||||
try:
|
||||
f()
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import create_locked_down_function
|
||||
function = """
|
||||
def add(a, b):
|
||||
def adder(a):
|
||||
return a + b
|
||||
return adder(b) + b
|
||||
"""
|
||||
f = create_locked_down_function(function)
|
||||
try:
|
||||
print(f(10, 20))
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
|
||||
# # Data & RL task setup
|
||||
#
|
||||
# We now have to create a prompt to tell the model to create a strategy for the 2048 game. You can customize this to some other task for another RL task.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
prompt = """
|
||||
Create a new short 2048 strategy using only native Python code.
|
||||
You are given a list of list of numbers for the current board state.
|
||||
Output one action for "W", "A", "S", "D" on what is the optimal next step.
|
||||
Output your new short function in backticks using the format below:
|
||||
```python
|
||||
def strategy(board):
|
||||
return "W" # Example
|
||||
```
|
||||
All helper functions should be inside def strategy. Only output the short function `strategy`.
|
||||
""".strip()
|
||||
print(prompt)
|
||||
|
||||
|
||||
# First, let's prompt Gemma 4 without RL and see how it goes:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
tokenize = False,
|
||||
add_generation_prompt = True,
|
||||
)
|
||||
|
||||
from transformers import TextStreamer
|
||||
print("=" * 50)
|
||||
print("BASE MODEL OUTPUT (before RL training):")
|
||||
print("=" * 50)
|
||||
|
||||
inputs = tokenizer(
|
||||
text = text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# # Reward functions
|
||||
#
|
||||
# We now design a `extract_function` function which simply extracts the function wrapped in 3 back ticks.
|
||||
#
|
||||
# And 3 reward functions:
|
||||
#
|
||||
# 1. `function_works` which rewards the model if the strategy is a valid Python function.
|
||||
# 2. `no_cheating` which checks if the function imported other modules, and if it did, we penalize it.
|
||||
# 3. `strategy_succeeds` which checks if the game strategy actually succeeds in attaining 2048 after running the auto-generated strategy.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def extract_function(text):
|
||||
if text.count("```") >= 2:
|
||||
first = text.find("```") + 3
|
||||
second = text.find("```", first)
|
||||
fx = text[first : second].strip()
|
||||
fx = fx.removeprefix("python\n")
|
||||
fx = fx[fx.find("def"):]
|
||||
if fx.startswith("def strategy(board):"): return fx
|
||||
return None
|
||||
print(extract_function(prompt))
|
||||
|
||||
|
||||
# Below is our `function_works` reward function which uses Python's `exec` but guarded by not allowing leakage of local and global variables. We can also use `check_python_modules` first to check if there are errors before even executing the function:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
ok, info = check_python_modules("def a")
|
||||
ok, info
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def function_works(completions, **kwargs):
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
if function is not None:
|
||||
ok, info = check_python_modules(function)
|
||||
if function is None or "error" in info:
|
||||
score = -2.0
|
||||
else:
|
||||
try:
|
||||
new_strategy = create_locked_down_function(function)
|
||||
score = 1.0
|
||||
except:
|
||||
score = -0.5
|
||||
scores.append(score)
|
||||
return scores
|
||||
|
||||
|
||||
# `no_cheating` checks if the function cheated since it might have imported Numpy or other functions:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def no_cheating(completions, **kwargs):
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
if function is not None:
|
||||
ok, info = check_python_modules(function)
|
||||
scores.append(1.0 if ok else -20.0) # Penalize heavily!
|
||||
else:
|
||||
scores.append(-1.0) # Failed creating function
|
||||
return scores
|
||||
|
||||
|
||||
# Next `strategy_succeeds` checks if the strategy actually allows the game to terminate. Imagine if the strategy simply returned "W" which would fail after a time limit of 10 seconds.
|
||||
#
|
||||
# We also add a global `PRINTER` to print out the strategy and board state.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
import numpy as np
|
||||
global PRINTER
|
||||
PRINTER = 0
|
||||
def strategy_succeeds(completions, **kwargs):
|
||||
global PRINTER
|
||||
scores = []
|
||||
# Generate a random game board with seed
|
||||
seed = np.random.randint(10000)
|
||||
for completion in completions:
|
||||
printed = False
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
if PRINTER % 5 == 0:
|
||||
printed = True
|
||||
print(function)
|
||||
PRINTER += 1
|
||||
if function is not None:
|
||||
ok, info = check_python_modules(function)
|
||||
if function is None or "error" in info:
|
||||
scores.append(0)
|
||||
continue
|
||||
try:
|
||||
new_strategy = create_locked_down_function(function)
|
||||
except:
|
||||
scores.append(0)
|
||||
continue
|
||||
try:
|
||||
game = GameBoard(size = 6, seed = seed, target = 2048, probability_fours = 0.10)
|
||||
steps, game_state = execute_strategy(new_strategy, game)
|
||||
print(f"Steps = {steps} State = {game_state}")
|
||||
if printed is False:
|
||||
print(function)
|
||||
print(game.board().pretty())
|
||||
if game_state == "success":
|
||||
scores.append(20.0) # Success - massively reward!
|
||||
else:
|
||||
scores.append(2.0) # Failed but function works!
|
||||
except TimeoutError as e:
|
||||
print("Timeout")
|
||||
scores.append(-1.0) # Failed with timeout
|
||||
except Exception as e:
|
||||
print(f"Exception = {str(e)}")
|
||||
scores.append(-3.0) # Failed
|
||||
return scores
|
||||
|
||||
|
||||
# We'll now create the dataset which includes a replica of our prompt.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from datasets import Dataset
|
||||
dataset = Dataset.from_list([{"prompt" : [{"role": "user", "content": prompt.strip()}], "answer" : 0}]*1000)
|
||||
maximum_length = len(tokenizer.apply_chat_template([{"role":"user", "content":prompt.strip()}], add_generation_prompt = True, tokenize = True))
|
||||
print(maximum_length)
|
||||
dataset[0]
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
#
|
||||
# Now set up GRPO Trainer and all configurations! We also support GSPO, GAPO, Dr GRPO and more! Go the Unsloth [Reinforcement Learning Docs](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) for more options.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Leave room for the prompt (plus 1 token safety margin)
|
||||
max_completion_length = max_seq_length - (maximum_length + 1)
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
training_args = GRPOConfig(
|
||||
temperature = 1.0,
|
||||
top_p = 0.95,
|
||||
top_k = 64,
|
||||
learning_rate = 5e-5,
|
||||
weight_decay = 0.001,
|
||||
warmup_ratio = 0.1,
|
||||
lr_scheduler_type = "linear",
|
||||
optim = "adamw_8bit",
|
||||
logging_steps = 1,
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 2, # Increase to 4 for smoother training
|
||||
num_generations = 2, # Decrease if out of memory
|
||||
max_completion_length = max_completion_length,
|
||||
# num_train_epochs = 1, # Set to 1 for a full training run
|
||||
max_steps = 60,
|
||||
save_steps = 100,
|
||||
report_to = "none", # Can use Weights & Biases, TrackIO
|
||||
output_dir = "outputs",
|
||||
epsilon = 0.2,
|
||||
epsilon_high = 0.28, # one sided
|
||||
delta = 1.5, # two sided
|
||||
loss_type = 'bnpo',
|
||||
mask_truncated_completions = True
|
||||
# For optional training + evaluation
|
||||
# fp16_full_eval = True,
|
||||
# per_device_eval_batch_size = 4,
|
||||
# eval_accumulation_steps = 1,
|
||||
# eval_strategy = "steps",
|
||||
# eval_steps = 1,
|
||||
)
|
||||
|
||||
|
||||
# And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!
|
||||
#
|
||||
# You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!
|
||||
#
|
||||
# | Step | Training Loss | reward | reward_std | completion_length | kl |
|
||||
# |------|---------------|-----------|------------|-------------------|----------|
|
||||
# | 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |
|
||||
# | 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |
|
||||
# | 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# For optional training + evaluation
|
||||
# new_dataset = dataset.train_test_split(test_size = 0.01)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model = model,
|
||||
processing_class = tokenizer,
|
||||
reward_funcs = [
|
||||
function_works,
|
||||
no_cheating,
|
||||
strategy_succeeds,
|
||||
],
|
||||
args = training_args,
|
||||
train_dataset = dataset,
|
||||
|
||||
# For optional training + evaluation
|
||||
# train_dataset = new_dataset["train"],
|
||||
# eval_dataset = new_dataset["test"],
|
||||
)
|
||||
|
||||
|
||||
# And let's train the model!
|
||||
#
|
||||
# **NOTE** A T4 free GPU might take 5 minutes for one generation sadly since it's an old GPU - A100 or H100 will be much faster!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
# And now with the LoRA we just trained with GRPO - we first save the LoRA first!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
|
||||
|
||||
# Verify LoRA is actually trained!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
tensors = {}
|
||||
with safe_open("grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
|
||||
# Verify both A and B are non zero
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
n_zeros = (tensor == 0).sum() / tensor.numel()
|
||||
assert(n_zeros.item() != tensor.numel())
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# # Inference
|
||||
# Now let's try the model we just trained!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
tokenize = False,
|
||||
add_generation_prompt = True,
|
||||
)
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
_ = model.generate(
|
||||
**tokenizer(images = None, text = text, return_tensors = "pt").to("cuda"),
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
max_new_tokens = 1024,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = False),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Merge to 16bit
|
||||
if False: model.save_pretrained_merged("gemma_4_finetune_16bit", tokenizer, save_method = "merged_16bit",)
|
||||
if False: model.push_to_hub_merged("HF_USERNAME/gemma_4_finetune_16bit", tokenizer, save_method = "merged_16bit", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Merge to 4bit
|
||||
if False: model.save_pretrained_merged("gemma_4_finetune_4bit", tokenizer, save_method = "merged_4bit",)
|
||||
if False: model.push_to_hub_merged("HF_USERNAME/gemma_4_finetune_4bit", tokenizer, save_method = "merged_4bit", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Just LoRA adapters
|
||||
if False:
|
||||
model.save_pretrained("gemma_4_lora")
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
if False:
|
||||
model.push_to_hub("HF_USERNAME/gemma_4_lora", token = "YOUR_HF_TOKEN")
|
||||
tokenizer.push_to_hub("HF_USERNAME/gemma_4_lora", token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.
|
||||
#
|
||||
# Some supported quant methods (full list on our [docs page](https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf)):
|
||||
# * `q8_0` - Fast conversion. High resource use, but generally acceptable.
|
||||
# * `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
|
||||
# * `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.
|
||||
#
|
||||
# [**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Save to 8bit Q8_0
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer,)
|
||||
# Remember to go to https://huggingface.co/settings/tokens for a token!
|
||||
# And change hf to your username!
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to 16bit GGUF
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer, quantization_method = "f16")
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, quantization_method = "f16", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to q4_k_m GGUF
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer, quantization_method = "q4_k_m")
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, quantization_method = "q4_k_m", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to multiple GGUF options - much faster if you want multiple!
|
||||
if False:
|
||||
model.push_to_hub_gguf(
|
||||
"HF_USERNAME/gemma_4_finetune", # Change hf to your username!
|
||||
tokenizer,
|
||||
quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma_4_finetune.Q8_0.gguf` file or `gemma_4_finetune.Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
+897
@@ -0,0 +1,897 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# # Goal: Make Gemma 4 solve Sudoku puzzles with Reinforcement Learning
|
||||
#
|
||||
# Our goal is to make Gemma 4 learn to solve Sudoku puzzles using reinforcement learning (GRPO).
|
||||
# The model will devise a strategy to fill in empty cells, and we'll reward it for correct placements
|
||||
# and completing valid puzzles.
|
||||
#
|
||||
# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/12/Sudoku_Puzzle_by_L2G-20050714_solution_standardized_layout.svg/1280px-Sudoku_Puzzle_by_L2G-20050714_solution_standardized_layout.svg.png" height="300" />
|
||||
|
||||
# # Installation
|
||||
# We'll be using [Unsloth](https://github.com/unslothai/unsloth) to do RL on Gemma 4. Unsloth saves 70% VRAM usage and makes reinforcement learning 2 to 6x faster.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
get_ipython().run_cell_magic('capture', '', 'import os, importlib.util\n!pip install --upgrade -qqq uv\nif importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):\n try: import numpy, PIL; _numpy = f"numpy=={numpy.__version__}"; _pil = f"pillow=={PIL.__version__}"\n except: _numpy = "numpy"; _pil = "pillow"\n # Gemma 4 requires transformers >= 5.5.0 — do NOT pin to 4.x here\n !uv pip install -qqq \\\n "torch>=2.8.0" "triton>=3.4.0" {_numpy} {_pil} torchvision bitsandbytes \\\n "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \\\n "unsloth[base] @ git+https://github.com/unslothai/unsloth" \\\n git+https://github.com/triton-lang/triton.git@0add68262ab0a2e33b84524346cb27cbb2787356#subdirectory=python/triton_kernels\nelif importlib.util.find_spec("unsloth") is None:\n !uv pip install -qqq unsloth\n# Gemma 4 requires transformers >= 5.5.0\n!uv pip install --upgrade --no-deps "transformers>=5.5.0" tokenizers "trl>=0.28.0" unsloth unsloth_zoo\n')
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
|
||||
|
||||
# ### Unsloth
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel
|
||||
import torch
|
||||
max_seq_length = 4096 # Can increase for longer reasoning traces
|
||||
lora_rank = 32 # Larger rank = smarter, but slower
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E2B-it",
|
||||
max_seq_length = max_seq_length,
|
||||
load_in_4bit = False, # False for LoRA 16bit
|
||||
fast_inference = False, # Enable vllm fast inference
|
||||
)
|
||||
|
||||
|
||||
# To do efficient RL, we will use [LoRA](https://arxiv.org/abs/2106.09685), which allows us to only add 1 to 5% of extra weights to the model for finetuning purposes. This allows us to save memory usage by over 60%, and yet it retains good accuracy.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
lora_alpha = lora_rank*2, # *2 speeds up training
|
||||
use_gradient_checkpointing = "unsloth", # Reduces memory usage
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# # Sudoku Game Implementation
|
||||
#
|
||||
# We use GPT-5 to create a clean Sudoku solver environment. The strategy outputs "row,col,value" to fill cells.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
#@title Sudoku Game Implementation
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Tuple, Optional
|
||||
import random
|
||||
import copy
|
||||
|
||||
def _is_valid_placement(board: List[List[int]], row: int, col: int, num: int) -> bool:
|
||||
"""Check if placing num at (row, col) is valid."""
|
||||
# Check row
|
||||
if num in board[row]:
|
||||
return False
|
||||
|
||||
# Check column
|
||||
if num in [board[r][col] for r in range(9)]:
|
||||
return False
|
||||
|
||||
# Check 3x3 box
|
||||
box_row, box_col = 3 * (row // 3), 3 * (col // 3)
|
||||
for r in range(box_row, box_row + 3):
|
||||
for c in range(box_col, box_col + 3):
|
||||
if board[r][c] == num:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _solve_sudoku(board: List[List[int]]) -> bool:
|
||||
"""Solve sudoku using backtracking (for puzzle generation)."""
|
||||
for row in range(9):
|
||||
for col in range(9):
|
||||
if board[row][col] == 0:
|
||||
for num in range(1, 10):
|
||||
if _is_valid_placement(board, row, col, num):
|
||||
board[row][col] = num
|
||||
if _solve_sudoku(board):
|
||||
return True
|
||||
board[row][col] = 0
|
||||
return False
|
||||
return True
|
||||
|
||||
def _generate_complete_board(rng: random.Random) -> List[List[int]]:
|
||||
"""Generate a complete valid Sudoku board."""
|
||||
board = [[0 for _ in range(9)] for _ in range(9)]
|
||||
|
||||
# Fill diagonal 3x3 boxes first (they don't affect each other)
|
||||
for box in range(3):
|
||||
nums = list(range(1, 10))
|
||||
rng.shuffle(nums)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
board[box * 3 + i][box * 3 + j] = nums[i * 3 + j]
|
||||
|
||||
# Solve the rest
|
||||
_solve_sudoku(board)
|
||||
return board
|
||||
|
||||
@dataclass
|
||||
class SudokuGame:
|
||||
difficulty: int = 40 # Number of cells to remove (20 = easy, 40 = medium, 50 = hard)
|
||||
seed: Optional[int] = None
|
||||
_rng: random.Random = field(init = False, repr = False)
|
||||
_board: List[List[int]] = field(init = False, repr = False)
|
||||
_solution: List[List[int]] = field(init = False, repr = False)
|
||||
_initial_board: List[List[int]] = field(init = False, repr = False)
|
||||
_moves: int = field(default = 0, init = False, repr = False)
|
||||
_state: str = field(default = "ongoing", init = False, repr = False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._rng = random.Random(self.seed)
|
||||
|
||||
# Generate complete board
|
||||
complete_board = _generate_complete_board(self._rng)
|
||||
self._solution = copy.deepcopy(complete_board)
|
||||
|
||||
# Remove cells to create puzzle
|
||||
self._board = copy.deepcopy(complete_board)
|
||||
cells = [(r, c) for r in range(9) for c in range(9)]
|
||||
self._rng.shuffle(cells)
|
||||
|
||||
for r, c in cells[:self.difficulty]:
|
||||
self._board[r][c] = 0
|
||||
|
||||
self._initial_board = copy.deepcopy(self._board)
|
||||
self._update_state()
|
||||
|
||||
def board(self) -> List[List[int]]:
|
||||
"""Return current board state."""
|
||||
return [row[:] for row in self._board]
|
||||
|
||||
def initial_board(self) -> List[List[int]]:
|
||||
"""Return initial puzzle state."""
|
||||
return [row[:] for row in self._initial_board]
|
||||
|
||||
def state(self) -> str:
|
||||
"""Return game state: 'ongoing', 'success', or 'failed'."""
|
||||
return self._state
|
||||
|
||||
def moves(self) -> int:
|
||||
"""Return number of moves made."""
|
||||
return self._moves
|
||||
|
||||
def place_number(self, row: int, col: int, num: int) -> bool:
|
||||
"""Place a number on the board. Returns True if valid move."""
|
||||
# Validate input
|
||||
if not (0 <= row < 9 and 0 <= col < 9):
|
||||
self._state = "failed"
|
||||
return False
|
||||
|
||||
if not (1 <= num <= 9):
|
||||
self._state = "failed"
|
||||
return False
|
||||
|
||||
# Can't modify initial cells
|
||||
if self._initial_board[row][col] != 0:
|
||||
self._state = "failed"
|
||||
return False
|
||||
if self._board[row][col] != 0:
|
||||
self._state = "failed"
|
||||
return False
|
||||
# Check if placement is valid
|
||||
if not _is_valid_placement(self._board, row, col, num):
|
||||
self._state = "failed"
|
||||
return False
|
||||
|
||||
# Place number
|
||||
self._board[row][col] = num
|
||||
self._moves += 1
|
||||
self._update_state()
|
||||
return True
|
||||
|
||||
def _update_state(self) -> None:
|
||||
"""Update game state based on current board."""
|
||||
# Check if puzzle is complete
|
||||
if all(self._board[r][c] != 0 for r in range(9) for c in range(9)):
|
||||
# Verify solution is correct
|
||||
if self._board == self._solution:
|
||||
self._state = "success"
|
||||
else:
|
||||
self._state = "failed"
|
||||
else:
|
||||
self._state = "ongoing"
|
||||
|
||||
def pretty(self, colors: bool = True) -> str:
|
||||
"""Pretty print the Sudoku board."""
|
||||
RESET = "\x1b[0m"
|
||||
INITIAL = "\x1b[38;5;45m" # Cyan for initial numbers
|
||||
PLACED = "\x1b[38;5;226m" # Yellow for placed numbers
|
||||
EMPTY = "\x1b[38;5;239m" # Gray for empty cells
|
||||
|
||||
lines = []
|
||||
lines.append("┌───────┬───────┬───────┐")
|
||||
|
||||
for row in range(9):
|
||||
row_str = "│ "
|
||||
for col in range(9):
|
||||
num = self._board[row][col]
|
||||
|
||||
if colors:
|
||||
if num == 0:
|
||||
row_str += f"{EMPTY}.{RESET}"
|
||||
elif self._initial_board[row][col] != 0:
|
||||
row_str += f"{INITIAL}{num}{RESET}"
|
||||
else:
|
||||
row_str += f"{PLACED}{num}{RESET}"
|
||||
else:
|
||||
row_str += str(num) if num != 0 else "."
|
||||
|
||||
if col % 3 == 2:
|
||||
row_str += " │ "
|
||||
else:
|
||||
row_str += " "
|
||||
|
||||
lines.append(row_str.rstrip())
|
||||
|
||||
if row == 8:
|
||||
lines.append("└───────┴───────┴───────┘")
|
||||
elif row % 3 == 2:
|
||||
lines.append("├───────┼───────┼───────┤")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Test the Sudoku environment:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Create an easy puzzle
|
||||
game = SudokuGame(difficulty = 30, seed = 42)
|
||||
print("Initial puzzle:")
|
||||
print(game.pretty())
|
||||
print(f"\nState: {game.state()}, Moves: {game.moves()}")
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
game
|
||||
|
||||
|
||||
# Try making some moves:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Make a valid move
|
||||
game.place_number(0, 1, 7)
|
||||
print("\nAfter placing 7 at (1,0):")
|
||||
print(game.pretty())
|
||||
print(f"State: {game.state()}, Moves: {game.moves()}")
|
||||
|
||||
|
||||
# If we do some other action that's not part of the action space, we will get an error, and the game will not accept anymore actions.
|
||||
|
||||
# # RL Environment Setup
|
||||
#
|
||||
# Execute strategies with time limits to prevent infinite loops.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from typing import Callable
|
||||
from unsloth import execute_with_time_limit
|
||||
|
||||
def _execute_strategy(strategy: Callable, game: SudokuGame):
|
||||
"""Execute a strategy function on a Sudoku game."""
|
||||
assert callable(strategy)
|
||||
|
||||
max_moves = 100
|
||||
valid_moves = 0 # Track successful moves
|
||||
|
||||
while game.state() == "ongoing" and valid_moves < max_moves:
|
||||
try:
|
||||
board = game.board()
|
||||
initial = game.initial_board()
|
||||
result = strategy(board, initial)
|
||||
|
||||
# Validate result format
|
||||
if not isinstance(result, (tuple, list)) or len(result) != 3:
|
||||
# Invalid format = immediate fail, but return valid moves made
|
||||
return valid_moves, "failed"
|
||||
|
||||
row, col, num = result
|
||||
|
||||
# Validate types
|
||||
if not all(isinstance(x, int) for x in [row, col, num]):
|
||||
return valid_moves, "failed"
|
||||
|
||||
# Try to place number
|
||||
success = game.place_number(row, col, num)
|
||||
|
||||
if success:
|
||||
valid_moves += 1 # Count this valid move
|
||||
else:
|
||||
# Invalid move = game fails, but return valid_moves made so far
|
||||
return valid_moves, "failed"
|
||||
|
||||
except Exception:
|
||||
return valid_moves, "failed"
|
||||
|
||||
if valid_moves >= max_moves and game.state() == "ongoing":
|
||||
return valid_moves, "failed"
|
||||
|
||||
return valid_moves, game.state()
|
||||
|
||||
|
||||
# To allow longer strategies for Reinforcement Learning, we shall allow a 10 second timer.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
@execute_with_time_limit(10)
|
||||
def execute_strategy(strategy: Callable, game: SudokuGame):
|
||||
"""Execute strategy with 10 second time limit."""
|
||||
return _execute_strategy(strategy, game)
|
||||
|
||||
|
||||
# Test with a simple strategy:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def simple_strategy(board, initial):
|
||||
"""Simple strategy: fill first empty cell with 1."""
|
||||
for r in range(9):
|
||||
for c in range(9):
|
||||
if board[r][c] == 0 and initial[r][c] == 0:
|
||||
return (r, c, 7)
|
||||
return (0, 0, 7)
|
||||
|
||||
game = SudokuGame(difficulty = 30, seed = 42)
|
||||
try:
|
||||
moves, state = execute_strategy(simple_strategy, game)
|
||||
print(f"Moves: {moves}, State: {state}")
|
||||
except TimeoutError as e:
|
||||
print(f"Timed out: {e}")
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
print(game.pretty())
|
||||
|
||||
|
||||
# # Code Execution
|
||||
#
|
||||
# To execute and create a new Python function, we first have to check if the function does not call other global variables or cheat. This is called `countering reward hacking` since we don't want the function to cheat.
|
||||
#
|
||||
# For example the below piece of code is fine, since it only imports Python level functions. We use `check_python_modules`:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from unsloth import check_python_modules, create_locked_down_function
|
||||
|
||||
# Test safe code
|
||||
sample = """
|
||||
def strategy(board, initial):
|
||||
for r in range(9):
|
||||
for c in range(9):
|
||||
if board[r][c] == 0:
|
||||
return (r, c, 1)
|
||||
return (0, 0, 1)
|
||||
"""
|
||||
|
||||
ok, info = check_python_modules(sample)
|
||||
print("Safe Python code?", ok)
|
||||
print(info)
|
||||
|
||||
|
||||
# For the below piece of code, since we import `numpy`, we should not allow the execution:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
sample = """
|
||||
def strategy(board, initial):
|
||||
import numpy as np
|
||||
return (0, 0, 1)
|
||||
"""
|
||||
|
||||
ok, info = check_python_modules(sample)
|
||||
print("Safe Python code?", ok)
|
||||
print(info)
|
||||
|
||||
|
||||
# # Data & RL task setup
|
||||
#
|
||||
# Create the prompt that instructs the model to generate a Sudoku solving strategy. You can customize this to some other task for another RL task.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
prompt = """
|
||||
Create a Sudoku solving strategy using only native Python built-in functions without any import statements.
|
||||
You are given two lists of lists (9x9 grids):
|
||||
- board: current state (0 means empty)
|
||||
- initial: starting puzzle (0 means was empty, numbers are fixed)
|
||||
|
||||
Return a tuple (row, col, number) for the next move.
|
||||
- row: 0-8 (row index)
|
||||
- col: 0-8 (column index)
|
||||
- number: 1-9 (digit to place)
|
||||
|
||||
Only place numbers in cells that are BOTH empty in initial AND empty in board (initial[row][col] == 0 AND board[row][col] == 0)
|
||||
Use Sudoku rules: no duplicates in rows, columns, or 3x3 boxes.
|
||||
Output your function in backticks:
|
||||
```python
|
||||
def strategy(board, initial):
|
||||
# Your logic here
|
||||
return (row, col, number)
|
||||
```
|
||||
All helper functions must be inside def strategy. Output only the function.
|
||||
""".strip()
|
||||
|
||||
print(prompt)
|
||||
|
||||
|
||||
# First, let's prompt the model without RL and see how it goes:
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
tokenize = False,
|
||||
add_generation_prompt = True,
|
||||
)
|
||||
|
||||
from transformers import TextStreamer
|
||||
print("=" * 50)
|
||||
print("BASE MODEL OUTPUT (before RL training):")
|
||||
print("=" * 50)
|
||||
|
||||
inputs = tokenizer(
|
||||
text = text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# # Reward functions
|
||||
#
|
||||
# We now design a `extract_function` function which simply extracts the function wrapped in 3 back ticks.
|
||||
#
|
||||
# And 3 reward functions:
|
||||
#
|
||||
# 1. `function_works` which rewards the model if the strategy is a valid Python function.
|
||||
# 2. `no_cheating` which checks if the function imported other modules, and if it did, we penalize it.
|
||||
# 3. `strategy_succeeds` which checks if the game strategy actually succeeds in attaining Sudoku after running the auto-generated strategy.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def extract_function(text):
|
||||
"""Extract Python function from markdown code blocks."""
|
||||
if text.count("```") >= 2:
|
||||
first = text.find("```") + 3
|
||||
second = text.find("```", first)
|
||||
fx = text[first:second].strip()
|
||||
fx = fx.removeprefix("python\n")
|
||||
fx = fx[fx.find("def"):]
|
||||
if fx.startswith("def strategy(board, initial):"):
|
||||
return fx
|
||||
return None
|
||||
|
||||
|
||||
# **Reward 1: Function Works**
|
||||
#
|
||||
# Checks if the generated code is valid Python and can be executed.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def function_works(completions, **kwargs):
|
||||
"""Reward for generating valid executable Python code."""
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
|
||||
if function is not None:
|
||||
ok, info = check_python_modules(function)
|
||||
|
||||
if function is None or "error" in info:
|
||||
score = -2.0 # Invalid function
|
||||
else:
|
||||
try:
|
||||
new_strategy = create_locked_down_function(function)
|
||||
score = 1.0 # Valid function
|
||||
except:
|
||||
score = -1.0 # Function has errors
|
||||
|
||||
scores.append(score)
|
||||
return scores
|
||||
|
||||
|
||||
# **Reward 2: No Cheating**
|
||||
#
|
||||
# Penalizes functions that import external libraries.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def no_cheating(completions, **kwargs):
|
||||
"""Penalize use of external imports."""
|
||||
scores = []
|
||||
for completion in completions:
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
|
||||
if function is not None:
|
||||
ok, info = check_python_modules(function)
|
||||
scores.append(1.0 if ok else -20.0) # Heavy penalty for cheating
|
||||
else:
|
||||
scores.append(-1.0) # Failed to create function
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
# **Reward 3: Strategy Succeeds**
|
||||
#
|
||||
# Rewards strategies that successfully solve Sudoku puzzles.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
global PRINTER
|
||||
PRINTER = 0
|
||||
|
||||
def strategy_succeeds(completions, **kwargs):
|
||||
"""Reward valid moves even if strategy eventually fails."""
|
||||
global PRINTER
|
||||
scores = []
|
||||
|
||||
seed = np.random.randint(10000)
|
||||
difficulty = 40
|
||||
for completion in completions:
|
||||
printed = False
|
||||
response = completion[0]["content"]
|
||||
function = extract_function(response)
|
||||
|
||||
if PRINTER % 5 == 0:
|
||||
printed = True
|
||||
print("\n" + "=" * 60)
|
||||
print(function)
|
||||
print("=" * 60)
|
||||
PRINTER += 1
|
||||
|
||||
if function is not None:
|
||||
ok, info = check_python_modules(function)
|
||||
|
||||
if function is None or "error" in info:
|
||||
scores.append(0)
|
||||
continue
|
||||
|
||||
try:
|
||||
new_strategy = create_locked_down_function(function)
|
||||
except:
|
||||
scores.append(0)
|
||||
continue
|
||||
|
||||
try:
|
||||
game = SudokuGame(difficulty = difficulty, seed = seed)
|
||||
valid_moves, game_state = execute_strategy(new_strategy, game)
|
||||
if valid_moves == difficulty:
|
||||
game_state = "success"
|
||||
|
||||
print(f"\n Valid moves: {valid_moves}, Final state: {game_state}")
|
||||
|
||||
if not printed:
|
||||
print("Strategy:")
|
||||
print(function[:200] + "..." if len(function) > 200 else function)
|
||||
|
||||
print("\nFinal board:")
|
||||
print(game.pretty())
|
||||
|
||||
if game_state == "success":
|
||||
scores.append(30.0) # Solved the puzzle!
|
||||
elif valid_moves > 0:
|
||||
# Reward based on valid moves made before failure
|
||||
# Each valid move is worth 0.2 points
|
||||
reward = valid_moves * 0.2
|
||||
scores.append(reward)
|
||||
else:
|
||||
scores.append(-2.0) # Failed immediately with no valid moves
|
||||
|
||||
except TimeoutError:
|
||||
print("Timeout")
|
||||
scores.append(-1.0)
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)[:100]}")
|
||||
scores.append(-3.0)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
# # Dataset Preparation
|
||||
#
|
||||
# Create the training dataset.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_list([
|
||||
{
|
||||
"prompt": [{"role": "user", "content": prompt.strip()}],
|
||||
"answer": 0,
|
||||
}
|
||||
] * 1000)
|
||||
|
||||
maximum_length = len(tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
add_generation_prompt = True
|
||||
))
|
||||
|
||||
print(f"Maximum prompt length: {maximum_length}")
|
||||
print("\nDataset sample:")
|
||||
print(dataset[0])
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
#
|
||||
# Now set up GRPO Trainer and all configurations! We also support GSPO, GAPO, Dr GRPO and more! Go the Unsloth [Reinforcement Learning Docs](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) for more options.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Leave room for the prompt (plus 1 token safety margin)
|
||||
max_completion_length = max_seq_length - (maximum_length + 1)
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
training_args = GRPOConfig(
|
||||
temperature = 1.0,
|
||||
learning_rate = 5e-5,
|
||||
weight_decay = 0.001,
|
||||
warmup_ratio = 0.1,
|
||||
lr_scheduler_type = "linear",
|
||||
optim = "adamw_8bit",
|
||||
logging_steps = 1,
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 2, # Increase to 4 for smoother training
|
||||
num_generations = 2, # Decrease if out of memory
|
||||
max_completion_length = max_completion_length,
|
||||
# num_train_epochs = 1, # Set to 1 for a full training run
|
||||
max_steps = 60,
|
||||
save_steps = 100,
|
||||
report_to = "none", # Can use Weights & Biases, TrackIO
|
||||
output_dir = "outputs",
|
||||
epsilon = 0.2,
|
||||
epsilon_high = 0.28, # one sided
|
||||
delta = 1.5, # two sided
|
||||
loss_type = 'bnpo',
|
||||
mask_truncated_completions = True
|
||||
# For optional training + evaluation
|
||||
# fp16_full_eval = True,
|
||||
# per_device_eval_batch_size = 4,
|
||||
# eval_accumulation_steps = 1,
|
||||
# eval_strategy = "steps",
|
||||
# eval_steps = 1,
|
||||
)
|
||||
|
||||
|
||||
# And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!
|
||||
#
|
||||
# You might have to wait 150 to 200 steps for any action. You'll probably get low reward for the first 100 steps. Please be patient!
|
||||
#
|
||||
# | Step | Training Loss | reward | reward_std | completion_length | kl |
|
||||
# |------|---------------|-----------|------------|-------------------|----------|
|
||||
# | 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |
|
||||
# | 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |
|
||||
# | 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# For optional training + evaluation
|
||||
# new_dataset = dataset.train_test_split(test_size = 0.01)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model = model,
|
||||
processing_class = tokenizer,
|
||||
reward_funcs = [
|
||||
function_works,
|
||||
no_cheating,
|
||||
strategy_succeeds,
|
||||
],
|
||||
args = training_args,
|
||||
train_dataset = dataset,
|
||||
|
||||
# For optional training + evaluation
|
||||
# train_dataset = new_dataset["train"],
|
||||
# eval_dataset = new_dataset["test"],
|
||||
)
|
||||
|
||||
|
||||
# And let's train the model!
|
||||
#
|
||||
# **NOTE** A T4 free GPU might take 5 minutes for one generation sadly since it's an old GPU - A100 or H100 will be much faster!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
# And now with the LoRA we just trained with GRPO - we first save the LoRA first!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
|
||||
|
||||
# Verify LoRA is actually trained!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
tensors = {}
|
||||
with safe_open("grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
|
||||
# Verify both A and B are non zero
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
n_zeros = (tensor == 0).sum() / tensor.numel()
|
||||
assert(n_zeros.item() != tensor.numel())
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# # Inference
|
||||
# Now let's try the model we just trained!
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt.strip()}],
|
||||
tokenize = False,
|
||||
add_generation_prompt = True,
|
||||
)
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
_ = model.generate(
|
||||
**tokenizer(images = None,text = text, return_tensors = "pt").to("cuda"),
|
||||
temperature = 1.0,
|
||||
max_new_tokens = 512,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = False),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Merge to 16bit
|
||||
if False: model.save_pretrained_merged("gemma_4_finetune_16bit", tokenizer, save_method = "merged_16bit",)
|
||||
if False: model.push_to_hub_merged("HF_USERNAME/gemma_4_finetune_16bit", tokenizer, save_method = "merged_16bit", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Merge to 4bit
|
||||
if False: model.save_pretrained_merged("gemma_4_finetune_4bit", tokenizer, save_method = "merged_4bit",)
|
||||
if False: model.push_to_hub_merged("HF_USERNAME/gemma_4_finetune_4bit", tokenizer, save_method = "merged_4bit", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Just LoRA adapters
|
||||
if False:
|
||||
model.save_pretrained("gemma_4_lora")
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
if False:
|
||||
model.push_to_hub("HF_USERNAME/gemma_4_lora", token = "YOUR_HF_TOKEN")
|
||||
tokenizer.push_to_hub("HF_USERNAME/gemma_4_lora", token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.
|
||||
#
|
||||
# Some supported quant methods (full list on our [docs page](https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf)):
|
||||
# * `q8_0` - Fast conversion. High resource use, but generally acceptable.
|
||||
# * `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
|
||||
# * `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.
|
||||
#
|
||||
# [**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
# Save to 8bit Q8_0
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer,)
|
||||
# Remember to go to https://huggingface.co/settings/tokens for a token!
|
||||
# And change hf to your username!
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to 16bit GGUF
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer, quantization_method = "f16")
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, quantization_method = "f16", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to q4_k_m GGUF
|
||||
if False: model.save_pretrained_gguf("gemma_4_finetune", tokenizer, quantization_method = "q4_k_m")
|
||||
if False: model.push_to_hub_gguf("HF_USERNAME/gemma_4_finetune", tokenizer, quantization_method = "q4_k_m", token = "YOUR_HF_TOKEN")
|
||||
|
||||
# Save to multiple GGUF options - much faster if you want multiple!
|
||||
if False:
|
||||
model.push_to_hub_gguf(
|
||||
"HF_USERNAME/gemma_4_finetune", # Change hf to your username!
|
||||
tokenizer,
|
||||
quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma_4_finetune.Q8_0.gguf` file or `gemma_4_finetune.Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,478 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
#
|
||||
# `FastModel` supports loading nearly any model now! This includes Vision and Text models!
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastModel
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
fourbit_models = [
|
||||
# Gemma 4 models
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, processor = FastModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E4B-it",
|
||||
dtype = None, # None for auto detection
|
||||
max_seq_length = 8192, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "YOUR_HF_TOKEN", # HF Token for gated models
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can process Text, Vision and Audio!
|
||||
#
|
||||
# Let's first experience how Gemma 4 can handle multimodal inputs. We use Gemma 4's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64` but for this example we use `do_sample=False` for ASR.
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
from transformers import TextStreamer
|
||||
# Helper function for inference
|
||||
def do_gemma_4_inference(messages, max_new_tokens = 128):
|
||||
_ = model.generate(
|
||||
**processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
return_tensors = "pt",
|
||||
).to("cuda"),
|
||||
max_new_tokens = max_new_tokens,
|
||||
do_sample = False,
|
||||
streamer = TextStreamer(processor, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# <h3>Let's Evaluate Gemma 4 Baseline Performance on German Transcription</h2>
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
from datasets import load_dataset,Audio,concatenate_datasets
|
||||
|
||||
dataset = load_dataset("kadirnar/Emilia-DE-B000000", split = "train")
|
||||
|
||||
# Select a single audio sample to reserve for testing.
|
||||
# This index is chosen from the full dataset before we create the smaller training split.
|
||||
test_audio = dataset[7546]
|
||||
|
||||
dataset = dataset.select(range(3000))
|
||||
|
||||
dataset = dataset.cast_column("audio", Audio(sampling_rate = 16000))
|
||||
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
from IPython.display import Audio, display
|
||||
print(test_audio['text'])
|
||||
Audio(test_audio['audio']['array'],rate = test_audio['audio']['sampling_rate'])
|
||||
|
||||
|
||||
# And the translation of the audio from German to English is:
|
||||
#
|
||||
# > I—I hold myself directly accountable. That much is, of course, clear: namely, that there are political interests involved in trade—in the exchange of goods—and that political influences are at play. The question is: that should not be the alternative.
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": test_audio['audio']['array']},
|
||||
{"type": "text", "text": "Please transcribe this audio."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# <h3>Baseline Model Performance: 32.43% Word Error Rate (WER) for this sample !</h3>
|
||||
|
||||
# # Let's finetune Gemma 4!
|
||||
#
|
||||
# You can finetune the vision and text and audio parts
|
||||
|
||||
# We now add LoRA adapters so we only need to update a small amount of parameters!
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # False if not finetuning vision layers
|
||||
finetune_language_layers = True, # False if not finetuning language layers
|
||||
finetune_attention_modules = True, # False if not finetuning attention layers
|
||||
finetune_mlp_modules = True, # False if not finetuning MLP layers
|
||||
|
||||
r = 8, # The larger, the higher the accuracy, but might overfit
|
||||
lora_alpha = 16, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
target_modules = [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
|
||||
# Audio layers
|
||||
"post", "linear_start", "linear_end",
|
||||
"embedding_projection",
|
||||
"ffw_layer_1", "ffw_layer_2",
|
||||
"output_proj",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We adapt the `kadirnar/Emilia-DE-B000000` dataset for our German ASR task using Gemma 4 multi-modal chat format. Each audio-text pair is structured into a conversation with `system`, `user`, and `assistant` roles. The processor then converts this into the final training format:
|
||||
#
|
||||
# ```
|
||||
# <bos><|turn>system
|
||||
# You are an assistant that transcribes speech accurately.<turn|>
|
||||
# <|turn>user
|
||||
# <|audio|>Please transcribe this audio.<turn|>
|
||||
# <|turn>model
|
||||
# Ich, ich rechne direkt mich an.<turn|>
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
def format_intersection_data(samples: dict) -> dict[str, list]:
|
||||
"""Format intersection dataset to match expected message format"""
|
||||
formatted_samples = {"messages": []}
|
||||
for idx in range(len(samples["audio"])):
|
||||
audio = samples["audio"][idx]["array"]
|
||||
label = str(samples["text"][idx])
|
||||
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": audio},
|
||||
{"type": "text", "text": "Please transcribe this audio."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content":[{"type": "text", "text": label}]
|
||||
}
|
||||
]
|
||||
formatted_samples["messages"].append(message)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
dataset = dataset.map(format_intersection_data, batched = True, batch_size = 4, num_proc = 4)
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`.
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
# Use UnslothVisionDataCollator which handles audio token alignment correctly
|
||||
from unsloth.trainer import UnslothVisionDataCollator
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = dataset,
|
||||
processing_class = processor.tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, processor),
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 8,
|
||||
gradient_accumulation_steps = 1,
|
||||
warmup_ratio = 0.03,
|
||||
# num_train_epochs = 1, # Use for full training runs
|
||||
max_steps = 60,
|
||||
learning_rate = 5e-5,
|
||||
logging_steps = 1,
|
||||
save_strategy = "steps",
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "cosine",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
remove_unused_columns = False,
|
||||
|
||||
# The below are a must for audio finetuning:
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
max_length = 8192,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# # Let's train the model!
|
||||
#
|
||||
# To resume a training run, set `trainer.train(resume_from_checkpoint = True)`
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model via Unsloth native inference! According to the `Gemma-4` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64` but for this example we use `do_sample=False` for ASR.
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": test_audio['audio']['array']},
|
||||
{"type": "text", "text": "Please transcribe this audio."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
processor.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# processor.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastModel
|
||||
model, processor = FastModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "What is Gemma-4?",}]
|
||||
}]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 128, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(processor, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly for deployment! We save it in the folder `gemma-4-finetune`. Set `if False` to `if True` to let it run!
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
if False: # Change to True to save finetune!
|
||||
model.save_pretrained_merged("gemma-4", processor)
|
||||
|
||||
|
||||
# If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
if False: # Change to True to upload finetune
|
||||
model.push_to_hub_merged(
|
||||
"HF_ACCOUNT/gemma-4-finetune", processor,
|
||||
token = "YOUR_HF_TOKEN"
|
||||
)
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
if False: # Change to True to save to GGUF
|
||||
model.save_pretrained_gguf(
|
||||
"gemma_4_finetune",
|
||||
processor,
|
||||
quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
|
||||
)
|
||||
|
||||
|
||||
# Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
if False: # Change to True to upload GGUF
|
||||
model.push_to_hub_gguf(
|
||||
"HF_ACCOUNT/gemma_4_finetune",
|
||||
processor,
|
||||
quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma-4-finetune.gguf` file or `gemma-4-finetune-Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,557 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a Google Colab L4 instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
#
|
||||
# `FastModel` supports loading nearly any model now! This includes Vision and Text models!
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastModel
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "unsloth/gemma-4-E4B-it",
|
||||
dtype = None, # None for auto detection
|
||||
max_seq_length = 1024, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "YOUR_HF_TOKEN", # HF Token for gated models
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can process Text, Vision and Audio!
|
||||
#
|
||||
# Let's first experience how Gemma 4 can handle multimodal inputs. We use Gemma 4's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
from transformers import TextStreamer
|
||||
# Helper function for inference
|
||||
def do_gemma_4_inference(messages, max_new_tokens = 128):
|
||||
_ = model.generate(
|
||||
**tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
return_tensors = "pt",
|
||||
).to("cuda"),
|
||||
max_new_tokens = max_new_tokens,
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
use_cache = True
|
||||
)
|
||||
|
||||
|
||||
# # Gemma 4 can see images!
|
||||
#
|
||||
# <img src="https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg" alt="Alt text" height="256">
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "image", "image" : sloth_link },
|
||||
{ "type": "text", "text" : "Which films does this animal feature in?" }
|
||||
]
|
||||
}]
|
||||
# You might have to wait 1 minute for Unsloth's auto compiler
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# Let's make a poem about sloths!
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{ "type" : "text",
|
||||
"text" : "Write a poem about sloths." }]
|
||||
}]
|
||||
do_gemma_4_inference(messages)
|
||||
|
||||
|
||||
# # Gemma 4 can also hear!
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
from IPython.display import Audio, display
|
||||
Audio("https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3")
|
||||
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
get_ipython().system('wget -qqq https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3 -O audio.mp3')
|
||||
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
audio_file = "audio.mp3"
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "audio", "audio" : audio_file },
|
||||
{ "type": "text", "text" : "What is this audio about?" }
|
||||
]
|
||||
}]
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# # Let's combine all 3 modalities together!
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role" : "user",
|
||||
"content": [
|
||||
{ "type": "audio", "audio" : audio_file },
|
||||
{ "type": "image", "image" : sloth_link },
|
||||
{ "type": "text", "text" : "What is this audio and image about? "\
|
||||
"How are they related?" }
|
||||
]
|
||||
}]
|
||||
do_gemma_4_inference(messages, max_new_tokens = 256)
|
||||
|
||||
|
||||
# # Let's finetune Gemma 4!
|
||||
#
|
||||
# You can finetune the vision and text parts for now through selection - the audio part can also be finetuned - we're working to make it selectable as well!
|
||||
|
||||
# We now add LoRA adapters so we only need to update a small amount of parameters!
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = False, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # Should leave on always!
|
||||
|
||||
r = 8, # Larger = higher accuracy, but might overfit
|
||||
lora_alpha = 8, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We now use the `Gemma-4` format for conversation style finetunes. We use [Maxime Labonne's FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) dataset in ShareGPT style. Gemma-4 renders multi turn conversations like below:
|
||||
#
|
||||
# ```
|
||||
# <bos><|turn>user
|
||||
# Hello<turn|>
|
||||
# <|turn>model
|
||||
# Hey there!<turn|>
|
||||
# ```
|
||||
# We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3, gemma-4` and more.
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4",
|
||||
)
|
||||
|
||||
|
||||
# We get the first 3000 rows of the dataset
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("mlabonne/FineTome-100k", split = "train[:3000]")
|
||||
|
||||
|
||||
# We now use `standardize_data_formats` to try converting datasets to the correct format for finetuning purposes!
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import standardize_data_formats
|
||||
dataset = standardize_data_formats(dataset)
|
||||
|
||||
|
||||
# Let's see how row 100 looks like!
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
dataset[100]
|
||||
|
||||
|
||||
# We now have to apply the chat template for `Gemma-4` onto the conversations, and save it to `text`. We remove the `<bos>` token using removeprefix(`'<bos>'`) since we're finetuning. The Processor will add this token before training and the model expects only one.
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversations"]
|
||||
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
|
||||
return { "text" : texts, }
|
||||
|
||||
dataset = dataset.map(formatting_prompts_func, batched = True)
|
||||
|
||||
|
||||
# Let's see how the chat template did! Notice there is no `<bos>` token as the processor tokenizer will be adding one.
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
dataset[100]["text"]
|
||||
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`.
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset,
|
||||
eval_dataset = None, # Can set up evaluation!
|
||||
args = SFTConfig(
|
||||
dataset_text_field = "text",
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4, # Use GA to mimic batch size!
|
||||
warmup_steps = 5,
|
||||
# num_train_epochs = 1, # Set this for 1 full training run.
|
||||
max_steps = 60,
|
||||
learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
|
||||
logging_steps = 1,
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
report_to = "none", # Use TrackIO/WandB etc
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# We also use Unsloth's `train_on_completions` method to only train on the assistant outputs and ignore the loss on the user's inputs. This helps increase accuracy of finetunes!
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import train_on_responses_only
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|turn>user\n",
|
||||
response_part = "<|turn>model\n",
|
||||
)
|
||||
|
||||
|
||||
# Let's verify masking the instruction part is done! Let's print the 100th row again. Notice how the sample only has a single `<bos>` as expected!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
tokenizer.decode(trainer.train_dataset[100]["input_ids"])
|
||||
|
||||
|
||||
# Now let's print the masked out example - you should see only the answer is present:
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")
|
||||
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# # Let's train the model!
|
||||
#
|
||||
# To resume a training run, set `trainer.train(resume_from_checkpoint = True)`
|
||||
|
||||
# In[23]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[24]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model via Unsloth native inference! According to the `Gemma-4` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`
|
||||
|
||||
# In[25]:
|
||||
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template = "gemma-4",
|
||||
)
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type" : "text",
|
||||
"text" : "Continue the sequence: 1, 1, 2, 3, 5, 8,",
|
||||
}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
)
|
||||
tokenizer.batch_decode(outputs)
|
||||
|
||||
|
||||
# You can also use a `TextStreamer` for continuous inference - so you can see the generation token by token, instead of waiting the whole time!
|
||||
|
||||
# In[26]:
|
||||
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "Why is the sky blue?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 64, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[27]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
tokenizer.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# tokenizer.push_to_hub("HF_ACCOUNT/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[28]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastModel
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type" : "text", "text" : "What is Gemma-4?",}]
|
||||
}]
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
_ = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens = 128, # Increase for longer outputs!
|
||||
# Recommended Gemma-4 settings!
|
||||
temperature = 1.0, top_p = 0.95, top_k = 64,
|
||||
streamer = TextStreamer(tokenizer, skip_prompt = True),
|
||||
)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly for deployment! We save it in the folder `gemma-4-finetune`. Set `if False` to `if True` to let it run!
|
||||
|
||||
# In[29]:
|
||||
|
||||
|
||||
if False: # Change to True to save finetune!
|
||||
model.save_pretrained_merged("gemma-4-finetune", tokenizer)
|
||||
|
||||
|
||||
# If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[30]:
|
||||
|
||||
|
||||
if False: # Change to True to upload finetune
|
||||
model.push_to_hub_merged(
|
||||
"HF_ACCOUNT/gemma-4-finetune", tokenizer,
|
||||
token = "YOUR_HF_TOKEN"
|
||||
)
|
||||
|
||||
|
||||
# ### GGUF / llama.cpp Conversion
|
||||
# To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
|
||||
|
||||
# In[31]:
|
||||
|
||||
|
||||
if False: # Change to True to save to GGUF
|
||||
model.save_pretrained_gguf(
|
||||
"gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
|
||||
)
|
||||
|
||||
|
||||
# Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!
|
||||
|
||||
# In[32]:
|
||||
|
||||
|
||||
if False: # Change to True to upload GGUF
|
||||
model.push_to_hub_gguf(
|
||||
"HF_ACCOUNT/gemma_4_finetune",
|
||||
tokenizer,
|
||||
quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
|
||||
token = "YOUR_HF_TOKEN",
|
||||
)
|
||||
|
||||
|
||||
# Now, use the `gemma-4-finetune.gguf` file or `gemma-4-finetune-Q4_K_M.gguf` file in llama.cpp.
|
||||
#
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# To run this, press "*Runtime*" and press "*Run all*" on a Google Colab L4 instance!
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
|
||||
# </div>
|
||||
#
|
||||
# To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
#
|
||||
# You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & how to save it
|
||||
|
||||
# ### News
|
||||
|
||||
# Introducing **Unsloth Studio** - a new open source, no-code web UI to train and run LLMs. [Blog](https://unsloth.ai/docs/new/studio) • [Notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb)
|
||||
#
|
||||
# <table><tr>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FxV1PO5DbF3ksB51nE2Tw%252Fmore%2520cropped%2520ui%2520for%2520homepage.png%3Falt%3Dmedia%26token%3Df75942c9-3d8d-4b59-8ba2-1a4a38de1b86&width=376&dpr=3&quality=100&sign=a663c397&sv=2" width="200" height="120" alt="Unsloth Studio Training UI"></a><br><sub><b>Train models</b> — no code needed</sub></td>
|
||||
# <td align="center"><a href="https://unsloth.ai/docs/new/studio"><img src="https://unsloth.ai/docs/~gitbook/image?url=https%3A%2F%2F3215535692-files.gitbook.io%2F~%2Ffiles%2Fv0%2Fb%2Fgitbook-x-prod.appspot.com%2Fo%2Fspaces%252FxhOjnexMCB3dmuQFQ2Zq%252Fuploads%252FRCnTAZ6Uh88DIlU3g0Ij%252Fmainpage%2520unsloth.png%3Falt%3Dmedia%26token%3D837c96b6-bd09-4e81-bc76-fa50421e9bfb&width=376&dpr=3&quality=100&sign=c1a39da1&sv=2" width="200" height="120" alt="Unsloth Studio Chat UI"></a><br><sub><b>Run GGUF models</b> on Mac, Windows & Linux</sub></td>
|
||||
# </tr></table>
|
||||
#
|
||||
# Train MoEs - DeepSeek, GLM, Qwen and gpt-oss 12x faster with 35% less VRAM. [Blog](https://unsloth.ai/docs/new/faster-moe)
|
||||
#
|
||||
# Ultra Long-Context Reinforcement Learning is here with 7x more context windows! [Blog](https://unsloth.ai/docs/new/grpo-long-context)
|
||||
#
|
||||
# New in Reinforcement Learning: [FP8 RL](https://unsloth.ai/docs/new/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/new/vision-reinforcement-learning-vlm-rl) • [Standby](https://unsloth.ai/docs/basics/memory-efficient-rl) • [gpt-oss RL](https://unsloth.ai/docs/new/gpt-oss-reinforcement-learning)
|
||||
#
|
||||
# Visit our docs for all our [model uploads](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks).
|
||||
|
||||
# # ### Installation
|
||||
#
|
||||
# # In[1]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', 'import os, re\nif "COLAB_" not in "".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r\'[\\d]{1,}\\.[\\d]{1,}\', str(torch.__version__)).group(0)\n xformers = \'xformers==\' + {\'2.10\':\'0.0.34\',\'2.9\':\'0.0.33.post1\',\'2.8\':\'0.0.32.post2\'}.get(v, "0.0.34")\n !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;\n')
|
||||
#
|
||||
#
|
||||
# # In[2]:
|
||||
#
|
||||
#
|
||||
# get_ipython().run_cell_magic('capture', '', '!pip install --no-deps --upgrade timm # For Gemma 4 vision/audio\n')
|
||||
#
|
||||
#
|
||||
# # ### Unsloth
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
from unsloth import FastVisionModel # FastLanguageModel for LLMs
|
||||
import torch
|
||||
|
||||
gemma4_models = [
|
||||
# Gemma-4 instruct models:
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
# Gemma-4 base models:
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
|
||||
)
|
||||
|
||||
|
||||
# We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.
|
||||
#
|
||||
# **[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers = True, # False if not finetuning vision layers
|
||||
finetune_language_layers = True, # False if not finetuning language layers
|
||||
finetune_attention_modules = True, # False if not finetuning attention layers
|
||||
finetune_mlp_modules = True, # False if not finetuning MLP layers
|
||||
|
||||
r = 32, # The larger, the higher the accuracy, but might overfit
|
||||
lora_alpha = 32, # Recommended alpha == r at least
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
target_modules = "all-linear", # Optional now! Can specify a list if needed
|
||||
)
|
||||
|
||||
|
||||
# <a name="Data"></a>
|
||||
# ### Data Prep
|
||||
# We'll use a sampled dataset of handwritten math formulas. The objective is to convert these images into a computer-readable format—specifically LaTeX—so they can be rendered. This is particularly useful for complex expressions.
|
||||
#
|
||||
# You can access the dataset [here](https://huggingface.co/datasets/unsloth/LaTeX_OCR). The full dataset is [here](https://huggingface.co/datasets/linxy/LaTeX_OCR).
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")
|
||||
|
||||
|
||||
# Let's take an overview of the dataset. We'll examine the second image and its corresponding caption.
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
dataset
|
||||
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
dataset[2]["image"]
|
||||
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
dataset[2]["text"]
|
||||
|
||||
|
||||
# We can also render LaTeX directly in the browser!
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
from IPython.display import display, Math, Latex
|
||||
|
||||
latex = dataset[3]["text"]
|
||||
display(Math(latex))
|
||||
|
||||
|
||||
# To format the dataset, all vision fine-tuning tasks should follow this format:
|
||||
#
|
||||
# ```python
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": instruction},
|
||||
# {"type": "image", "image": sample["image"]},
|
||||
# ],
|
||||
# },
|
||||
# ]
|
||||
# ```
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
def convert_to_conversation(sample):
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": instruction},
|
||||
{"type": "image", "image": sample["image"]},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
|
||||
]
|
||||
return {"messages": conversation}
|
||||
pass
|
||||
|
||||
|
||||
# Let's convert the dataset into the "correct" format for finetuning:
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
|
||||
|
||||
|
||||
# The first example is now structured like below:
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
converted_dataset[0]
|
||||
|
||||
|
||||
# Lets take the Gemma 4 instruction chat template and use it in our base model
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
from unsloth import get_chat_template
|
||||
|
||||
processor = get_chat_template(
|
||||
processor,
|
||||
"gemma-4"
|
||||
)
|
||||
|
||||
|
||||
# Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before.
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
image = dataset[2]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# You can see it's absolutely terrible! It doesn't follow instructions at all
|
||||
|
||||
# <a name="Train"></a>
|
||||
# ### Train the model
|
||||
# Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support `DPOTrainer` and `GRPOTrainer` for reinforcement learning!
|
||||
#
|
||||
# We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup.
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
from unsloth.trainer import UnslothVisionDataCollator
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = converted_dataset,
|
||||
processing_class = processor.tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, processor),
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4,
|
||||
max_grad_norm = 0.3,
|
||||
warmup_ratio = 0.03,
|
||||
max_steps = 60,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
learning_rate = 2e-4,
|
||||
logging_steps = 1,
|
||||
save_strategy = "steps",
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.001,
|
||||
lr_scheduler_type = "cosine",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none", # For Weights and Biases or others
|
||||
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
max_length = 2048,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
# @title Show current memory stats
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
|
||||
print(f"{start_gpu_memory} GB of memory reserved.")
|
||||
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
trainer_stats = trainer.train()
|
||||
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
# @title Show final memory and time stats
|
||||
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
|
||||
used_percentage = round(used_memory / max_memory * 100, 3)
|
||||
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
|
||||
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||
print(
|
||||
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
|
||||
)
|
||||
print(f"Peak reserved memory = {used_memory} GB.")
|
||||
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
|
||||
|
||||
|
||||
# <a name="Inference"></a>
|
||||
# ### Inference
|
||||
# Let's run the model! You can modify the instruction and input—just leave the output blank.
|
||||
#
|
||||
# We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`.
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
image = dataset[10]["image"]
|
||||
instruction = "Write the LaTeX representation for this image."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": instruction}],
|
||||
}
|
||||
]
|
||||
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor, skip_prompt = True)
|
||||
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# <a name="Save"></a>
|
||||
# ### Saving, loading finetuned models
|
||||
# To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.
|
||||
#
|
||||
# **[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
model.save_pretrained("gemma_4_lora") # Local saving
|
||||
processor.save_pretrained("gemma_4_lora")
|
||||
# model.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
# processor.push_to_hub("your_name/gemma_4_lora", token = "YOUR_HF_TOKEN") # Online saving
|
||||
|
||||
|
||||
# Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
if False:
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
model, processor = FastVisionModel.from_pretrained(
|
||||
model_name = "gemma_4_lora", # YOUR MODEL YOU USED FOR TRAINING
|
||||
load_in_4bit = True, # Set to False for 16bit LoRA
|
||||
)
|
||||
|
||||
sample = dataset[1]
|
||||
image = sample["image"].convert("RGB")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sample["text"],
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
input_text = processor.apply_chat_template(messages, add_generation_prompt = True)
|
||||
inputs = processor(
|
||||
image,
|
||||
input_text,
|
||||
add_special_tokens = False,
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
from transformers import TextStreamer
|
||||
|
||||
text_streamer = TextStreamer(processor.tokenizer, skip_prompt = True)
|
||||
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
|
||||
use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)
|
||||
|
||||
|
||||
# ### Saving to float16 for VLLM
|
||||
#
|
||||
# We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens. See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options.
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# Select ONLY 1 to save! (Both not needed!)
|
||||
|
||||
# Save locally to 16bit
|
||||
if False: model.save_pretrained_merged("unsloth_finetune", processor,)
|
||||
|
||||
# To export and save to your Hugging Face account
|
||||
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", processor, token = "YOUR_HF_TOKEN")
|
||||
|
||||
|
||||
# And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
|
||||
#
|
||||
# Some other resources:
|
||||
# 1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
|
||||
# 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
|
||||
# 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
|
||||
# 4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!
|
||||
#
|
||||
# <div class="align-center">
|
||||
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
|
||||
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
|
||||
# <a href="https://unsloth.ai/docs/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
|
||||
#
|
||||
# Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
|
||||
# </div>
|
||||
#
|
||||
# This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
|
||||
@@ -0,0 +1,88 @@
|
||||
# CodeGemma
|
||||
|
||||
Code completion / generation with native **fill-in-the-middle (FIM)** support. Built on **Gemma 1** — still the most recent generation as of April 2026. No CodeGemma 2/3/4 release.
|
||||
|
||||
## What it is
|
||||
|
||||
Gemma 1 fine-tuned on code. Trained with 80–90% FIM rate, 50/50 split between PSM (Prefix-Suffix-Middle) and SPM (Suffix-Prefix-Middle) formats. Designed for IDE autocomplete more than chat.
|
||||
|
||||
## Sizes
|
||||
|
||||
- **2B pretrained** — fast completion
|
||||
- **7B pretrained** — higher quality completion + FIM
|
||||
- **7B instruction-tuned** — code chat
|
||||
|
||||
Versioned point releases exist (2B 1.1, 7B-IT 1.1).
|
||||
|
||||
## Model card
|
||||
|
||||
- https://ai.google.dev/gemma/docs/codegemma/model_card
|
||||
- HF: https://huggingface.co/google/codegemma-7b
|
||||
- Tech report: https://arxiv.org/abs/2406.11409
|
||||
|
||||
## FIM tokens
|
||||
|
||||
```
|
||||
<|fim_prefix|> prefix-of-completion marker
|
||||
<|fim_suffix|> cursor/insertion-point marker
|
||||
<|fim_middle|> generation trigger
|
||||
<|file_separator|> multi-file boundary
|
||||
```
|
||||
|
||||
### PSM (Prefix-Suffix-Middle) template
|
||||
|
||||
```
|
||||
<|fim_prefix|>[code before cursor]<|fim_suffix|>[code after cursor]<|fim_middle|>
|
||||
```
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
prompt = (
|
||||
"<|fim_prefix|>import datetime\n"
|
||||
"def calculate_age(birth_year):\n"
|
||||
" current_year = datetime.date.today().year\n"
|
||||
" <|fim_suffix|>\n"
|
||||
" return age<|fim_middle|>"
|
||||
)
|
||||
```
|
||||
|
||||
The model generates the middle chunk and halts.
|
||||
|
||||
### Multi-file context
|
||||
|
||||
Prepend referenced files separated by `<|file_separator|>`, then the target file in FIM format.
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
model_id = "google/codegemma-7b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
|
||||
prompt = "<|fim_prefix|>def fib(n):\n if n <= 1:\n return n\n <|fim_suffix|>\n return a<|fim_middle|>"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
out = model.generate(**inputs, max_new_tokens=128)
|
||||
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False))
|
||||
```
|
||||
|
||||
## Ollama
|
||||
|
||||
`ollama pull codegemma:7b` or `codegemma:2b`. Ollama wraps the FIM tokens for you when you use its completion API with prefix/suffix.
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
- You need **IDE-grade FIM autocomplete** — CodeGemma was trained for it, base Gemma 4 was not.
|
||||
- You want a **2B code model** — base Gemma 4 skips this size (E2B is multimodal, not code-specialized).
|
||||
- You want **Ollama-native FIM** that tools like `continue.dev` can talk to.
|
||||
|
||||
Base Gemma 4 31B still beats CodeGemma 7B on LiveCodeBench, so for **agentic coding** (plan, write, execute) Gemma 4 or `qwen3-coder:30b` wins. CodeGemma is the inline-cursor-assistant niche.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Steel141 already has qwen3-coder:30b and qwen3-coder-next:79.7B — those are stronger than CodeGemma 7B. Only reason to pull CodeGemma is if you want a tiny 2B FIM model for a latency-sensitive editor integration on a Pi or on pve197 alongside the vision stack.
|
||||
@@ -0,0 +1,76 @@
|
||||
# DataGemma
|
||||
|
||||
LLM grounding with Google **Data Commons** — a public knowledge graph of 240B+ statistical data points (economics, health, demographics, science). Built on **Gemma 2 27B**. No Gemma 3 or 4 generation yet.
|
||||
|
||||
## What it is
|
||||
|
||||
Two flavors:
|
||||
|
||||
- **DataGemma RIG** (Retrieval-Interleaved Generation): Model is fine-tuned to emit inline Data Commons queries wrapped around its own claims. Outputs look like `The population of Sunnyvale is [__DC__("population of Sunnyvale") --> "152,200"]`. An external resolver substitutes the real stat.
|
||||
- **DataGemma RAG** (Retrieval-Augmented Generation): Standard RAG pipeline — query Data Commons, inject results into context, generate.
|
||||
|
||||
## Sizes
|
||||
|
||||
- **27B instruct** only (`datagemma-rig-27b-it`, `datagemma-rag-27b-it`).
|
||||
|
||||
## Model cards
|
||||
|
||||
- https://ai.google.dev/gemma/docs/datagemma
|
||||
- DeepMind: https://deepmind.google/models/gemma/datagemma/
|
||||
- HF RIG: https://huggingface.co/google/datagemma-rig-27b-it
|
||||
- HF RAG: https://huggingface.co/google/datagemma-rag-27b-it
|
||||
- Paper: https://docs.datacommons.org/papers/DataGemma-FullPaper.pdf
|
||||
|
||||
## Performance claim
|
||||
|
||||
Baseline Gemma 2 factuality on the 101-query statistical eval: **5–17%**. DataGemma RIG: **~58%**. The improvement is narrow (statistical claims only) but real.
|
||||
|
||||
## Prompt format
|
||||
|
||||
No special template. Plain natural-language input. The difference is in the **training** and the **output format**.
|
||||
|
||||
**RIG output example:**
|
||||
```
|
||||
Sunnyvale has [__DC__("total population of Sunnyvale CA") --> "152,200"]
|
||||
residents as of 2020, with a median age of [__DC__("median age of
|
||||
Sunnyvale CA") --> "34.8"].
|
||||
```
|
||||
|
||||
Post-processing: regex out the `[__DC__("...") --> "..."]` blocks and either (a) replace with resolved Data Commons values, or (b) render as inline citations.
|
||||
|
||||
**RAG flow:** query Data Commons first, inject tabular context, then prompt normally.
|
||||
|
||||
## Minimum invocation — RIG
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
model_id = "google/datagemma-rig-27b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, device_map="auto", torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
prompt = "What are the demographic trends in Sunnyvale, California?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
out = model.generate(**inputs, max_new_tokens=1024)
|
||||
print(tokenizer.batch_decode(
|
||||
out[:, inputs["input_ids"].shape[1]:],
|
||||
skip_special_tokens=True
|
||||
)[0])
|
||||
```
|
||||
|
||||
Then run a resolver that extracts each `[__DC__(q) --> ""]` and hits the Data Commons API.
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
- You're building a **statistics-grounded assistant** (government data, public health, economic indicators) and need low hallucination on numbers.
|
||||
- You're okay with a **27B model** — DataGemma only ships at this size.
|
||||
- Your domain overlaps Data Commons coverage (US-heavy, but growing internationally).
|
||||
|
||||
Base Gemma 4 + a conventional RAG pipeline can do the same thing if you bring your own retriever. DataGemma's value is the **trained inline-citation behavior** (RIG) — Gemma 4 won't emit that format without prompting gymnastics.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Low. No current Seth project leans on statistical grounding. Niche for a news-summary use case (POS-Automation daily print) if Seth ever wants "US inflation was X% as of Y" kind of interjections — but then a simple Data Commons API call from the script is cheaper than running a 27B model.
|
||||
@@ -0,0 +1,44 @@
|
||||
# DolphinGemma
|
||||
|
||||
Marine biology / dolphin vocalization model. Developed with the Wild Dolphin Project (WDP) and Georgia Tech. Announced April 2025.
|
||||
|
||||
## Status
|
||||
|
||||
**Not publicly released as of April 2026.** DeepMind's page states "DolphinGemma is currently in development. On release, it will be openly available." No weights on Hugging Face, Kaggle, or Google AI for Developers. Google's 2025 post anticipated a summer 2025 open-source release; that slipped.
|
||||
|
||||
If you see a `dolphingemma-*` tag somewhere, it is either community-named (not Google) or a leaked checkpoint. Verify the uploader is `google/` on HF.
|
||||
|
||||
## What it is (from announcement material)
|
||||
|
||||
- **Audio-in, audio-out** model.
|
||||
- Trained on tens of thousands of hours of Atlantic spotted dolphin vocalizations.
|
||||
- Predicts the next sound in a sequence (same training objective as an LLM, just in the audio token domain).
|
||||
- **~400M parameters** — small enough to run on a Pixel phone in the field.
|
||||
- Intended to plug into the CHAT (Cetacean Hearing Augmentation Telemetry) system to accelerate real-time pattern recognition during dolphin interactions.
|
||||
|
||||
## Base generation
|
||||
|
||||
Announced as "built on Google's open Gemma series." Google has not disclosed which generation. Given the mid-2025 timing and 400M size, most likely Gemma 3-era tech, but **this is an educated guess**, not confirmed.
|
||||
|
||||
## Model card
|
||||
|
||||
- DeepMind: https://deepmind.google/models/gemma/dolphingemma/
|
||||
- Blog: https://blog.google/innovation-and-ai/products/dolphingemma/
|
||||
|
||||
No model card on ai.google.dev yet (expected once released).
|
||||
|
||||
## Prompt format
|
||||
|
||||
Not published. The audio-token I/O format will depend on the tokenizer Google picked (e.g., SoundStream, Whisper-style, or a custom cetacean-phoneme tokenizer). Wait for release.
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
Not possible. No weights available.
|
||||
|
||||
## When to choose it
|
||||
|
||||
If and when it ships: marine biology research, specifically Atlantic spotted dolphins. Fine-tunable for other cetacean species per Google.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Zero for normal use. If it ships and Seth wants a novelty "run the model on a cheap Pi and watch it hallucinate dolphin whistles" project, it's a candidate for the 400M-parameter slot on seth-pi. Until then, nothing to deploy.
|
||||
@@ -0,0 +1,93 @@
|
||||
# EmbeddingGemma
|
||||
|
||||
On-device text embedding model. Released **September 2025**. Built on **Gemma 3 with T5Gemma initialization**. No Gemma 4 generation yet.
|
||||
|
||||
## What it is
|
||||
|
||||
A **308M-parameter** open embedding model. Trained on 100+ languages. State-of-the-art on MTEB for its size class. Uses **Matryoshka Representation Learning (MRL)** — one model produces embeddings at 768, 512, 256, or 128 dimensions by truncation + renormalization, with graceful quality degradation.
|
||||
|
||||
## Sizes
|
||||
|
||||
- **308M** — only size.
|
||||
|
||||
## Model card
|
||||
|
||||
- https://ai.google.dev/gemma/docs/embeddinggemma/model_card
|
||||
- HF: https://huggingface.co/google/embeddinggemma-300m
|
||||
- HF blog: https://huggingface.co/blog/embeddinggemma
|
||||
- DeepMind: https://deepmind.google/models/gemma/embeddinggemma/
|
||||
- Paper: https://arxiv.org/html/2509.20354v2
|
||||
|
||||
## Prompt format
|
||||
|
||||
EmbeddingGemma uses **task-prefixed inputs** — you prepend a task descriptor to each string before embedding.
|
||||
|
||||
### Query prompts
|
||||
|
||||
```
|
||||
task: {task description} | query: {your query}
|
||||
```
|
||||
|
||||
Default task description: `search result`.
|
||||
|
||||
Example: `task: search result | query: what is the capital of France?`
|
||||
|
||||
### Document prompts
|
||||
|
||||
```
|
||||
title: {title or "none"} | text: {document text}
|
||||
```
|
||||
|
||||
Providing a real title improves retrieval; use `none` if unavailable.
|
||||
|
||||
Example: `title: Eiffel Tower | text: The Eiffel Tower is a wrought-iron lattice tower...`
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
### Sentence-Transformers (easy path)
|
||||
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer("google/embeddinggemma-300m")
|
||||
|
||||
query = "Which planet is known as the Red Planet?"
|
||||
documents = [
|
||||
"Mars, known for its reddish appearance, is often referred to as the Red Planet.",
|
||||
"Venus is often called Earth's twin due to its similar size.",
|
||||
]
|
||||
|
||||
q_emb = model.encode_query(query)
|
||||
d_emb = model.encode_document(documents)
|
||||
|
||||
print(model.similarity(q_emb, d_emb))
|
||||
```
|
||||
|
||||
The `encode_query` / `encode_document` methods apply the task prefixes automatically.
|
||||
|
||||
### Shorter embeddings (MRL)
|
||||
|
||||
```python
|
||||
emb_768 = model.encode(text) # full
|
||||
emb_256 = emb_768[:, :256] # truncate
|
||||
emb_256 = emb_256 / emb_256.norm(dim=-1, keepdim=True) # renormalize
|
||||
```
|
||||
|
||||
## Gotcha
|
||||
|
||||
**Activations do not support `float16`.** Use `bfloat16` or `float32`. This is explicit in the model card.
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
Always, when you want embeddings. Base Gemma 4 is a generative decoder — not trained as an embedding model. EmbeddingGemma is the correct tool for retrieval, clustering, semantic search, RAG.
|
||||
|
||||
Its main competitor is `nomic-embed-text` (already in Seth's pantry). EmbeddingGemma's MRL and multilingual coverage (100+ vs. nomic's ~English-focused) are the differentiators.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
**Highest-impact variant for Seth right now, along with TranslateGemma.**
|
||||
|
||||
- **Family history agent:** 100+ language support + 128d embeddings = tight, multilingual indices over scanned documents, letters, census records. MRL lets you serve fast 128d approximate search and fall back to 768d for reranking.
|
||||
- **SearXNG / SethSearch:** drop-in upgrade from nomic-embed-text for the semantic-search layer. Bigger model but better quality.
|
||||
- **Mortdecai memory:** use 308M EmbeddingGemma for long-term memory over chat logs. Small enough to run alongside the big mortdecai qwen35 models on pve197 or steel141 without resource contention.
|
||||
- **Gemma-cookbook already has a tutorial** (`tutorials_RAG_EmbeddingGemma.ipynb` in the corpus) — skip straight to working code.
|
||||
@@ -0,0 +1,55 @@
|
||||
# Gemma family index (as of April 2026)
|
||||
|
||||
Specialized sister models Google has released alongside base Gemma. Base Gemma 4 instruct/base variants are **not** listed here — they live in the main corpus at `/home/claude/bin/gemma4-research/`.
|
||||
|
||||
## Summary table
|
||||
|
||||
| Variant | Base gen | Sizes | Canonical use case | HF URL |
|
||||
|---|---|---|---|---|
|
||||
| **ShieldGemma** | Gemma 2 | 2B, 9B, 27B | Text safety classification (4 harm types) | [google/shieldgemma-2b](https://huggingface.co/google/shieldgemma-2b) |
|
||||
| **ShieldGemma 2** | Gemma 3 | 4B | Image safety classification (3 categories) | [google/shieldgemma-2-4b-it](https://huggingface.co/google/shieldgemma-2-4b-it) |
|
||||
| **CodeGemma** | Gemma 1 | 2B, 7B, 7B-IT | Code completion with FIM tokens | [google/codegemma-7b](https://huggingface.co/google/codegemma-7b) |
|
||||
| **PaliGemma** | Gemma 1 | 3B | Vision-language (task-prefix prompting) | [google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448) |
|
||||
| **PaliGemma 2** | Gemma 2 | 3B, 10B, 28B | Vision-language, multi-resolution | [google/paligemma2-3b-pt-448](https://huggingface.co/google/paligemma2-3b-pt-448) |
|
||||
| **RecurrentGemma** | Gemma 1 | 2B, 9B | Griffin architecture, long-context throughput | [google/recurrentgemma-9b](https://huggingface.co/google/recurrentgemma-9b) |
|
||||
| **DataGemma (RIG/RAG)** | Gemma 2 | 27B | Statistical grounding via Google Data Commons | [google/datagemma-rig-27b-it](https://huggingface.co/google/datagemma-rig-27b-it) |
|
||||
| **MedGemma 1.5** | Gemma 3 | 4B multimodal | Medical text + image comprehension (non-clinical) | [google/medgemma-1.5-4b-it](https://huggingface.co/google/medgemma-1.5-4b-it) |
|
||||
| **TxGemma** | Gemma 2 | 2B, 9B, 27B | Therapeutics/drug-discovery prediction | [google/txgemma-27b-predict](https://huggingface.co/google/txgemma-27b-predict) |
|
||||
| **DolphinGemma** | Gemma (unstated) | ~400M | Marine biology / dolphin vocalization | *Not released as of April 2026* |
|
||||
| **SignGemma** | Gemma 3-era | small on-device | ASL → English translation | *Limited preview only; no public weights as of April 2026* |
|
||||
| **TranslateGemma** | Gemma 3 | 4B, 12B, 27B | 55-language text + image translation | [google/translategemma-4b-it](https://huggingface.co/google/translategemma-4b-it) |
|
||||
| **EmbeddingGemma** | Gemma 3 (T5Gemma init) | 308M | On-device text embeddings, MRL (768/512/256/128) | [google/embeddinggemma-300m](https://huggingface.co/google/embeddinggemma-300m) |
|
||||
| **T5Gemma / T5Gemma 2** | Gemma 2 / Gemma 3 | small → 4B-4B | Encoder-decoder for summarization, translation | [google/t5gemma-2-4b-4b](https://huggingface.co/google/t5gemma-2-4b-4b) |
|
||||
| **FunctionGemma** | Gemma 3 | 270M | Function-calling specialist | [google/functiongemma-270m](https://huggingface.co/google/functiongemma-270m) |
|
||||
| **VaultGemma** | Gemma 3 | 1B | Differential-privacy-trained LLM | [google/vaultgemma-1b](https://huggingface.co/google/vaultgemma-1b) |
|
||||
| **Gemma-APS** | Gemma 2 | 2B, 7B | Abstractive proposition segmentation | — |
|
||||
| **Gemma Scope / Scope 2** | Gemma 2/3 | SAE suite | Mechanistic interpretability | [google/gemma-scope](https://huggingface.co/google/gemma-scope) |
|
||||
|
||||
## Gemma 4 generation status
|
||||
|
||||
**As of 2026-04-18, no specialized sister model has been re-based to Gemma 4.** Every variant in the table above is built on Gemma 1, 2, or 3. The newest specialized releases (TranslateGemma, Jan 2026; T5Gemma 2, Dec 2025) still sit on Gemma 3. This is normal for Google's cadence — sisters lag the base release by 3–6 months. Expect a MedGemma-on-Gemma-4, ShieldGemma-3-on-Gemma-4, and PaliGemma 3 over summer/fall 2026.
|
||||
|
||||
## Per-variant files
|
||||
|
||||
- `shieldgemma.md` — covers both ShieldGemma (text) and ShieldGemma 2 (image)
|
||||
- `codegemma.md`
|
||||
- `paligemma.md` — covers both PaliGemma and PaliGemma 2
|
||||
- `recurrentgemma.md`
|
||||
- `datagemma.md`
|
||||
- `medgemma.md`
|
||||
- `txgemma.md`
|
||||
- `dolphingemma.md`
|
||||
- `signgemma.md`
|
||||
- `translategemma.md`
|
||||
- `embeddinggemma.md`
|
||||
- `other-variants.md` — T5Gemma, FunctionGemma, VaultGemma, Gemma-APS, Gemma Scope
|
||||
|
||||
## Picking a variant for homelab use
|
||||
|
||||
Short read — see individual files for depth.
|
||||
|
||||
- **Minecraft agent (Mortdecai):** consider `FunctionGemma` (270M) as a fast-path tool-router in front of the big `mortdecai:*` models. Today's setup uses the base `qwen35`/`mortdecai` tool calling, but FunctionGemma's 270M size makes it cheap enough to run as a gateway classifier.
|
||||
- **AI music video gen / visualizer:** `PaliGemma 2` for detailed captioning of reference frames; `ShieldGemma 2` to pre-filter generated output before publishing. Base Gemma 4 vision (tested in existing corpus) handles the "describe this image" job fine — reach for PaliGemma 2 when you need spatial grounding (detect/segment task prefixes).
|
||||
- **Family history agent:** `EmbeddingGemma` (308M) is the immediate win — small, multilingual, 100+ languages, MRL to 128d for tight indices. Pair with `TranslateGemma` if sources are in German/Polish/etc. For ingest of old scanned documents, `PaliGemma 2` + `TranslateGemma` handles image-embedded text translation.
|
||||
- **General safety pass for anything going public:** `ShieldGemma 2` for images, `ShieldGemma` (Gemma 2-based) for text. Both run comfortably on pve197's CT 105.
|
||||
- **Skip for homelab:** MedGemma (disclaimer-laden, not clinical-grade, niche), TxGemma (drug discovery, highly specialist), DolphinGemma (not released), SignGemma (limited preview, no weights).
|
||||
@@ -0,0 +1,72 @@
|
||||
# MedGemma
|
||||
|
||||
Medical-domain variant for text + image comprehension. Current release is **MedGemma 1.5** (Jan 13, 2026), built on **Gemma 3**. **No Gemma 4 generation.**
|
||||
|
||||
## What it is
|
||||
|
||||
Gemma 3 fine-tuned on de-identified medical corpora — clinical notes, radiology images, dermatology images, histopathology, etc. The multimodal variants use a SigLIP image encoder trained specifically on medical imagery (not the base SigLIP).
|
||||
|
||||
## Sizes
|
||||
|
||||
**MedGemma 1.5** (current): **4B multimodal IT only**. Previous 27B variants were in MedGemma 1; 1.5 currently ships 4B only with improvements in medical reasoning, records interpretation, and image interpretation.
|
||||
|
||||
**MedGemma 1** (prior): 4B multimodal, 27B text-only, 27B multimodal.
|
||||
|
||||
## Model card
|
||||
|
||||
- https://developers.google.com/health-ai-developer-foundations/medgemma/model-card
|
||||
- DeepMind: https://deepmind.google/models/gemma/medgemma/
|
||||
- Repo: https://github.com/google-health/medgemma
|
||||
- Tech report: https://arxiv.org/abs/2507.05201
|
||||
|
||||
## Intended use
|
||||
|
||||
"A starting point that enables more efficient development of downstream healthcare applications involving medical text and images." **Developer tool, not a clinical product.**
|
||||
|
||||
### Disclaimer (near-verbatim from model card)
|
||||
|
||||
> The outputs generated by MedGemma are not intended to directly inform clinical diagnosis, patient management decisions, treatment recommendations, or any other direct clinical practice applications. All outputs require independent verification and clinical correlation.
|
||||
|
||||
Terms of use are governed by **Health AI Developer Foundations** — a separate license from base Gemma's. Read it before shipping anything.
|
||||
|
||||
## Prompt format
|
||||
|
||||
Standard Gemma 3 chat template. Content messages accept `{"type": "image"}` and `{"type": "text"}`.
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
from PIL import Image
|
||||
import requests, torch
|
||||
|
||||
pipe = pipeline(
|
||||
"image-text-to-text",
|
||||
model="google/medgemma-1.5-4b-it",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
img_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
|
||||
image = Image.open(requests.get(img_url, stream=True).raw)
|
||||
|
||||
messages = [{"role": "user", "content": [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": "Describe this chest X-ray. What anatomical structures are visible?"},
|
||||
]}]
|
||||
|
||||
out = pipe(text=messages, max_new_tokens=512)
|
||||
print(out[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
- You're building **healthcare dev tools** (medical image triage assistant, doctor-facing records summarizer, clinician education) and want the SigLIP-medical image encoder.
|
||||
- You can accept the Health AI Developer Foundations license and embed the disclaimers.
|
||||
- You need **medical-vocabulary fluency** (SNOMED, ICD, RxNorm) that base Gemma 4 doesn't have at the 4B size.
|
||||
|
||||
Use base Gemma 4 otherwise — including for health-adjacent content that isn't clinical (fitness logs, nutrition, sleep data).
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Zero. Seth is not running medical apps. Noted for completeness only.
|
||||
@@ -0,0 +1,74 @@
|
||||
# Other Gemma variants
|
||||
|
||||
Smaller / more specialized sisters that don't warrant a full file each. All on Gemma 2 or Gemma 3. **None on Gemma 4 as of April 2026.**
|
||||
|
||||
## T5Gemma / T5Gemma 2
|
||||
|
||||
**Encoder-decoder** Gemma, built by adapting decoder-only Gemma weights into a T5-style encoder-decoder via UL2 or PrefixLM pretraining.
|
||||
|
||||
- **T5Gemma** (Jul 2025): Gemma 2-based. Sizes include 2B-2B, 9B-2B, 9B-9B plus new T5-sized small/base/large/XL models.
|
||||
- **T5Gemma 2** (Dec 2025): Gemma 3-based. Sizes: 270M-270M, 1B-1B, 4B-4B. Multimodal (128K context).
|
||||
|
||||
### When to pick it
|
||||
|
||||
- **Summarization, translation, QA** where the encoder's separate bidirectional attention buys quality.
|
||||
- Anywhere a decoder-only Gemma feels wasteful for "read input, compress into short output" tasks.
|
||||
|
||||
HF: https://huggingface.co/google/t5gemma-2-4b-4b
|
||||
Blog: https://developers.googleblog.com/en/t5gemma/
|
||||
|
||||
## FunctionGemma
|
||||
|
||||
**270M tool/function-calling specialist.** Gemma 3-based. Released Dec 2025.
|
||||
|
||||
Trained to emit structured function calls given a tool catalog. Not a generalist chat model — feed it a user message + tool schemas and it picks the right tool. Tiny enough to run as a pre-router in front of a larger model.
|
||||
|
||||
### When to pick it
|
||||
|
||||
- **Minecraft agent (Mortdecai):** plausibly interesting — use it as a 270M gateway that classifies intent and picks one of the Mortdecai tools, then hands off to the bigger `mortdecai:*` model for reasoning. Latency/cost savings if the tool decision is hot-path.
|
||||
- Any agent where tool-selection volume is high and model call cost matters.
|
||||
|
||||
HF: search `google/functiongemma-270m`.
|
||||
|
||||
## VaultGemma
|
||||
|
||||
**1B Gemma 3 trained with differential privacy.** Released Sep 2025.
|
||||
|
||||
The point is the training process (DP-SGD with rigorous privacy budget) more than the weights per se. Useful as a reference checkpoint or for deployments where "model cannot have memorized training data" is a hard requirement.
|
||||
|
||||
### When to pick it
|
||||
|
||||
- Niche. You almost never need DP-trained weights unless you're in regulated space.
|
||||
|
||||
## Gemma-APS
|
||||
|
||||
**Abstractive Proposition Segmentation.** 2B and 7B on Gemma 2. Oct 2024.
|
||||
|
||||
Takes a passage, splits it into atomic propositions (self-contained factual statements). Useful for fact-checking, citation mapping, and as a preprocessing step for RAG indexing.
|
||||
|
||||
### When to pick it
|
||||
|
||||
- Building a **fact-verification pipeline** where you need to decompose generated text into checkable claims.
|
||||
- **Family history** — could decompose narrative biographical text into timestamped facts for structured storage.
|
||||
|
||||
## Gemma Scope / Gemma Scope 2
|
||||
|
||||
Sparse autoencoder (SAE) suites for **mechanistic interpretability** research. Gemma Scope on Gemma 2, Gemma Scope 2 on Gemma 3 (Dec 2025).
|
||||
|
||||
Not models you deploy for product work. Tools for "which neurons activate on what" research.
|
||||
|
||||
HF: https://huggingface.co/google/gemma-scope
|
||||
|
||||
### When to pick it
|
||||
|
||||
- Interpretability research only. Not a homelab deployment candidate.
|
||||
|
||||
## Summary of homelab relevance
|
||||
|
||||
| Variant | Homelab fit |
|
||||
|---|---|
|
||||
| T5Gemma 2 4B-4B | Moderate — summarization for the news-briefing printer |
|
||||
| FunctionGemma 270M | **High — tool-router for Mortdecai** |
|
||||
| VaultGemma | None |
|
||||
| Gemma-APS | Low-moderate — niche preprocessing step |
|
||||
| Gemma Scope | None (research tool) |
|
||||
@@ -0,0 +1,80 @@
|
||||
# PaliGemma / PaliGemma 2
|
||||
|
||||
Vision-language model combining a **SigLIP** image encoder with a Gemma text decoder. Separate product line from base Gemma 4's built-in vision. Still on Gemma 2 as of April 2026 — **no PaliGemma 3 or PaliGemma-on-Gemma-4 yet.**
|
||||
|
||||
## What it is
|
||||
|
||||
- **PaliGemma** (May 2024): Gemma 1 + SigLIP-So400m/14. Sizes: 3B only. Built for task-prefix prompting (`caption`, `detect`, `segment`, `ocr`).
|
||||
- **PaliGemma 2** (Dec 2024): Gemma 2 + SigLIP-So400m/14. Sizes: 3B, 10B, 28B. Each available at three resolutions: 224x224, 448x448, 896x896.
|
||||
- **PaliGemma 2 mix** (Feb 2025): task-mixed instruction-tuned variant — works better out-of-the-box on ad-hoc VQA without per-task fine-tuning.
|
||||
|
||||
## Sizes (PaliGemma 2)
|
||||
|
||||
| Text decoder | Image encoder | Total | Resolutions |
|
||||
|---|---|---|---|
|
||||
| Gemma 2 2B | SigLIP-So400m | ~3B | 224 / 448 / 896 |
|
||||
| Gemma 2 9B | SigLIP-So400m | ~10B | 224 / 448 / 896 |
|
||||
| Gemma 2 27B | SigLIP-So400m | ~28B | 224 / 448 / 896 |
|
||||
|
||||
## Model cards
|
||||
|
||||
- PaliGemma 2: https://ai.google.dev/gemma/docs/paligemma/model-card-2
|
||||
- DeepMind: https://deepmind.google/models/gemma/paligemma-2/
|
||||
- HF blog: https://huggingface.co/blog/paligemma2
|
||||
|
||||
## Prompt format
|
||||
|
||||
PaliGemma uses **task-prefix** prompting, not chat turns. Format:
|
||||
|
||||
```
|
||||
<image>{task} {args}
|
||||
```
|
||||
|
||||
Known task prefixes (not exhaustive; Google under-documents the full list):
|
||||
|
||||
| Prefix | Purpose | Example |
|
||||
|---|---|---|
|
||||
| `caption {lang}` | Image captioning | `<image>caption en` |
|
||||
| `ocr` | Read all text in image | `<image>ocr` |
|
||||
| `answer en {q}` | VQA | `<image>answer en what color is the car?` |
|
||||
| `detect {obj}` | Object detection (bounding boxes) | `<image>detect cat ; dog` |
|
||||
| `segment {obj}` | Segmentation masks | `<image>segment person` |
|
||||
|
||||
For `detect` and `segment`, output uses custom location (`<loc0123>`) and segmentation (`<seg000>`) tokens. You need the PaliGemma postprocessing routines to convert them to pixel coords.
|
||||
|
||||
## Minimum invocation — PaliGemma 2
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
from PIL import Image
|
||||
import requests, torch
|
||||
|
||||
model_id = "google/paligemma2-3b-mix-448"
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
image = Image.open(requests.get(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
|
||||
stream=True
|
||||
).raw).convert("RGB")
|
||||
|
||||
prompt = "<image>caption en"
|
||||
inputs = processor(prompt, image, return_tensors="pt").to("cuda")
|
||||
out = model.generate(**inputs, max_new_tokens=200)
|
||||
gen = processor.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
|
||||
print(gen)
|
||||
```
|
||||
|
||||
## When to choose it over base Gemma 4 vision
|
||||
|
||||
- You need **structured spatial output** — bounding boxes, segmentation masks. Base Gemma 4 vision returns freeform text; PaliGemma 2 returns grid-aligned location tokens.
|
||||
- You're doing **pure VQA or captioning at scale** and want a smaller, faster, task-specialized 3B model (vs. Gemma 4 E4B at 4B-effective).
|
||||
- You're **fine-tuning** for a narrow vision task — PaliGemma 2 is explicitly designed to be easy to fine-tune; Google ships LoRA recipes.
|
||||
|
||||
Use base Gemma 4 for **conversational multimodal** (back-and-forth with images + text reasoning). PaliGemma is the "turn image into structured text" workhorse.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
For `ai-visualizer` (CT 167, pve197 with V100): PaliGemma 2 3B-448 is a great caption-and-ground step when producing SDXL prompts from reference images. Already tested: base Gemma 4 E4B handles "describe this image" at ~25 tok/s on pve197. PaliGemma 2 would add `detect`/`segment` for spatial control (e.g., "put the character in the upper-left quadrant of the generated scene").
|
||||
@@ -0,0 +1,67 @@
|
||||
# RecurrentGemma
|
||||
|
||||
Griffin-architecture sibling. Built on **Gemma 1**. No Gemma 2/3/4 generation — the line has effectively stalled, with long-context Transformer variants (Gemma 4 with 256K context) overtaking the memory-efficiency argument.
|
||||
|
||||
## What it is
|
||||
|
||||
Gated linear recurrences + local sliding-window attention, replacing full self-attention. Fixed-size hidden state → **O(1) memory per token generated**, no KV cache growth. Inference stays fast and cheap as context lengthens.
|
||||
|
||||
## Sizes
|
||||
|
||||
- **2B** pretrained + instruct
|
||||
- **9B** pretrained + instruct
|
||||
|
||||
Only two sizes. No 27B. Griffin scaling beyond 9B is an open research question and Google didn't ship it.
|
||||
|
||||
## Model card
|
||||
|
||||
- https://ai.google.dev/gemma/docs/recurrentgemma/model_card
|
||||
- DeepMind: https://deepmind.google/models/gemma/recurrentgemma/
|
||||
- Paper: https://arxiv.org/abs/2404.07839
|
||||
- Repo: https://github.com/google-deepmind/recurrentgemma
|
||||
|
||||
## Architecture highlights
|
||||
|
||||
- **Griffin block:** alternates two residual recurrent blocks with a local MQA attention block.
|
||||
- **State size:** fixed — independent of sequence length.
|
||||
- **Sliding window:** local attention only, not global.
|
||||
- **Trade-off:** loses some needle-in-haystack precision vs. a full-attention Transformer, gains memory flatness.
|
||||
|
||||
## Prompt format
|
||||
|
||||
Standard Gemma turn format — same `<start_of_turn>user … <end_of_turn>` as Gemma 1 IT. No RecurrentGemma-specific tokens.
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
model_id = "google/recurrentgemma-9b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
|
||||
prompt = "<start_of_turn>user\nWrite a haiku about memory.<end_of_turn>\n<start_of_turn>model\n"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
out = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(out[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
Honestly: **rarely, in April 2026.**
|
||||
|
||||
The original pitch was "long-context generation without KV blowup." Gemma 4 now ships with 256K context on the 26B/31B and 128K on the edge models, with efficient attention implementations. The gap RecurrentGemma was filling has narrowed.
|
||||
|
||||
Reasonable residual cases:
|
||||
- **Extremely memory-constrained hardware** (Jetson Nano tier) where even quantized Gemma 4 E2B KV cache is the limiting factor on sequence length.
|
||||
- **Streaming-generation workloads** where latency-per-token must stay constant as output length grows into the tens of thousands of tokens.
|
||||
- **Research interest** in recurrent LLMs.
|
||||
|
||||
For typical homelab use, skip. The V100 on pve197 has 32GB VRAM; Gemma 4 31B at Q4 fits with room for generous context.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Not a strong candidate for any current Seth project. Note for file: if a CPU-only streaming-transcript use case ever comes up (e.g., running on seth-pi for always-on audio processing), RecurrentGemma 2B could reappear in scope.
|
||||
@@ -0,0 +1,89 @@
|
||||
# ShieldGemma / ShieldGemma 2
|
||||
|
||||
Safety classifiers. Two separate product lines now: **ShieldGemma** (text, built on Gemma 2) and **ShieldGemma 2** (images, built on Gemma 3). There is no Gemma 4 generation yet.
|
||||
|
||||
## What it is
|
||||
|
||||
- **ShieldGemma (text):** LLM-as-a-judge safety classifier. Takes a prompt + optional model response + a policy, emits `Yes`/`No` (yes = violates policy). Four harm types.
|
||||
- **ShieldGemma 2 (image):** Image classifier. Takes a PIL image, emits probabilities across three image-safety categories. Image-only — does not accept text.
|
||||
|
||||
## Sizes
|
||||
|
||||
- ShieldGemma: **2B, 9B, 27B** — all instruction-tuned.
|
||||
- ShieldGemma 2: **4B** — only size available.
|
||||
|
||||
## Model cards
|
||||
|
||||
- Text: https://ai.google.dev/gemma/docs/shieldgemma/model_card
|
||||
- Image: https://huggingface.co/google/shieldgemma-2-4b-it
|
||||
- DeepMind: https://deepmind.google/models/gemma/shieldgemma-2/
|
||||
|
||||
## Safety categories
|
||||
|
||||
**ShieldGemma (text):**
|
||||
1. Sexually explicit content
|
||||
2. Dangerous content
|
||||
3. Hate speech
|
||||
4. Harassment
|
||||
|
||||
**ShieldGemma 2 (image):**
|
||||
1. Sexually explicit content
|
||||
2. Dangerous content
|
||||
3. Violence / gore
|
||||
|
||||
Note the image model dropped "hate" and "harassment" (hard to define visually) and added "violence/gore" (a visual primitive).
|
||||
|
||||
## Text prompt format (ShieldGemma)
|
||||
|
||||
Five-component structure:
|
||||
|
||||
```
|
||||
<preamble establishing "you are a policy expert">
|
||||
|
||||
<start_of_turn>user
|
||||
<user prompt here>
|
||||
<end_of_turn>
|
||||
|
||||
<start_of_turn>model
|
||||
<optional model response here>
|
||||
<end_of_turn>
|
||||
|
||||
Our safety principle is defined in the below:
|
||||
* <policy description, e.g. "No Hate Speech": ...>
|
||||
|
||||
Does the human prompt/response violate the above principle? Your answer must start with 'Yes' or 'No'.
|
||||
```
|
||||
|
||||
The model outputs one token: `Yes` (violates) or `No` (safe). Softmax the logits on those two tokens for a calibrated score.
|
||||
|
||||
## Minimum invocation — ShieldGemma 2 (image)
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
model_id = "google/shieldgemma-2-4b-it"
|
||||
model = ShieldGemma2ForImageClassification.from_pretrained(model_id).eval()
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
image = Image.open("input.jpg")
|
||||
inputs = processor(images=[image], return_tensors="pt")
|
||||
|
||||
with torch.inference_mode():
|
||||
out = model(**inputs)
|
||||
|
||||
print(out.probabilities) # tensor of per-category "Yes" probabilities
|
||||
```
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
- You need a **calibrated safety score**, not a free-form "is this safe?" answer from the chat model. ShieldGemma emits Yes/No token logits — easy to threshold.
|
||||
- You want **policy-by-policy classification** (e.g., run each category separately with different thresholds).
|
||||
- You're running a moderation pipeline and need **a small, fast, purpose-trained classifier** rather than a general chat model reasoning about safety.
|
||||
|
||||
Use base Gemma 4 for "explain *why* this is unsafe" narrative output. ShieldGemma is the yes/no stamp.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Pre-filter for `ai-visualizer` (CT 167, pve197) before publishing generated images. ShieldGemma 2 4B at Q4 fits comfortably on the Tesla V100-PCIE-32GB alongside SDXL.
|
||||
@@ -0,0 +1,43 @@
|
||||
# SignGemma
|
||||
|
||||
ASL (American Sign Language) → English translation model. Announced at Google I/O 2025.
|
||||
|
||||
## Status
|
||||
|
||||
**Limited preview only. No open weights as of April 2026.** Google published an interest form at I/O 2025; access has been gated to language-service providers, accessibility researchers, and members of the Deaf community. Participants receive a TensorFlow Lite package and sample integration code.
|
||||
|
||||
There is no public Hugging Face entry under `google/signgemma*`. The original plan was general availability by end-of-2025, which slipped. No updated timeline announced as of April 2026.
|
||||
|
||||
## What it is (from announcement material)
|
||||
|
||||
- **Video-in, text-out** on-device model.
|
||||
- Best performance on **ASL → English**; training includes other sign languages for future expansion.
|
||||
- Uses a **vision transformer** to analyze hand shapes, facial expressions, and motion, followed by a compact language model that produces English output.
|
||||
- Sized for **smartphones and laptops** — on-device real-time translation is the design goal.
|
||||
|
||||
## Base generation
|
||||
|
||||
Google states it is "part of the Gemma family" and "built on the Gemini Nano framework." Likely Gemma 3-era image/video encoder on a small Gemma 3 text decoder — **not confirmed**, and the "Gemini Nano framework" language suggests it may use Gemini-not-Gemma internals despite the name. Verify at release.
|
||||
|
||||
## Model card
|
||||
|
||||
- LinkedIn announcement: https://www.linkedin.com/posts/googledeepmind_signgemma-is-our-most-advanced-model-for-activity-7342957078249955329-JwJJ
|
||||
- Slator coverage: https://slator.com/google-invites-feedback-for-signgemma-a-new-ai-sign-language-translation-model/
|
||||
|
||||
No public model card yet.
|
||||
|
||||
## Prompt format
|
||||
|
||||
Not published.
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
Not possible. No weights available.
|
||||
|
||||
## When to choose it
|
||||
|
||||
On release: accessibility apps, live captioning for Deaf users, sign-language learning tools.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Zero for typical homelab use. If Seth ever wants to pilot a real-time captioning overlay for video streams this could matter — but not buildable until Google ships weights.
|
||||
@@ -0,0 +1,105 @@
|
||||
# TranslateGemma
|
||||
|
||||
Multilingual text + image translation. Released **January 15, 2026**. Built on **Gemma 3** (not Gemma 4, despite being the newest variant at time of writing).
|
||||
|
||||
## What it is
|
||||
|
||||
Gemma 3 fine-tuned for translation across **55 languages**, using a two-stage distillation from Gemini. Retains Gemma 3's multimodal capability — can translate text embedded in images.
|
||||
|
||||
## Sizes
|
||||
|
||||
- **4B IT**
|
||||
- **12B IT**
|
||||
- **27B IT**
|
||||
|
||||
Google's headline claim: the 12B beats Gemma 3 27B baseline translation quality with less than half the parameters.
|
||||
|
||||
## Model card
|
||||
|
||||
- HF: https://huggingface.co/google/translategemma-4b-it
|
||||
- Blog: https://blog.google/innovation-and-ai/technology/developers-tools/translategemma/
|
||||
- InfoQ: https://www.infoq.com/news/2026/01/google-translategemma-models/
|
||||
|
||||
## Supported languages
|
||||
|
||||
55 languages via ISO 639-1 codes (`en`, `de`, `es`, `fr`, `pl`, `ja`, `zh`, `ar`, `hi`, etc.) plus regional variants (`en-US`, `en-GB`, `pt-BR`, `pt-PT`, `de-DE`, `de-AT`, `de-CH`, `zh-CN`, `zh-TW`, etc.).
|
||||
|
||||
## Prompt format
|
||||
|
||||
**Strict chat-template format.** Content list must contain exactly **one entry**, with mandatory `source_lang_code` and `target_lang_code`.
|
||||
|
||||
### Text translation
|
||||
|
||||
```python
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"source_lang_code": "cs",
|
||||
"target_lang_code": "de-DE",
|
||||
"text": "V nejhorším případě i k prasknutí čočky.",
|
||||
}],
|
||||
}]
|
||||
```
|
||||
|
||||
### Image translation (translates text inside the image)
|
||||
|
||||
```python
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "image",
|
||||
"source_lang_code": "ja",
|
||||
"target_lang_code": "en",
|
||||
"url": "https://example.com/japanese-sign.jpg",
|
||||
}],
|
||||
}]
|
||||
```
|
||||
|
||||
Only `"text"` and `"image"` types are supported. Only `user` and `assistant` roles. Image input is normalized to 896×896 (256 vision tokens).
|
||||
|
||||
## Minimum invocation
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
pipe = pipeline(
|
||||
"image-text-to-text",
|
||||
model="google/translategemma-4b-it",
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"source_lang_code": "pl",
|
||||
"target_lang_code": "en",
|
||||
"text": "Dziadek mieszkał w Warszawie przed wojną.",
|
||||
}],
|
||||
}]
|
||||
|
||||
out = pipe(text=messages, max_new_tokens=200)
|
||||
print(out[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
- **WMT24++ across 55 languages:** MetricX 5.32, COMET 81.6.
|
||||
- Context window: 2K tokens (short — this is a translation model, not a long-doc summarizer).
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
- You want **translation quality > general Gemma 4** at equivalent size, with the strict prompt contract making it easy to drop into a pipeline.
|
||||
- You need **image-text translation** (street signs, menus, old documents) as a first-class task.
|
||||
- You care about the 55-language coverage and regionalized variants.
|
||||
|
||||
Base Gemma 4 31B *can* translate — fine for casual use. TranslateGemma wins for production pipelines and when you care about metric-validated quality.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
**Strong fit for family history agent.** If source documents are in German, Polish, Hungarian, Yiddish, or any of the 55 supported languages, TranslateGemma 4B on pve197 (GPU-backed) becomes the translation leg of an ingest pipeline: OCR → TranslateGemma → Gemma 4 for reasoning. The 4B size fits alongside the other models on the V100.
|
||||
|
||||
Also useful for SearchXNG (if Seth ever wants to auto-translate non-English search results) and the news-summary print system (translate foreign-language feeds before summarization).
|
||||
@@ -0,0 +1,63 @@
|
||||
# TxGemma
|
||||
|
||||
Therapeutic-development / drug-discovery variant. Built on **Gemma 2**. No Gemma 3 or 4 generation yet.
|
||||
|
||||
## What it is
|
||||
|
||||
Gemma 2 fine-tuned on 7M examples curated from the **Therapeutics Data Commons (TDC)** — predictive tasks across small molecules, proteins, nucleic acids, diseases, and cell lines. Beats or matches state-of-the-art on 50 of 66 TDC tasks; beats specialist models on 26 of them.
|
||||
|
||||
## Sizes
|
||||
|
||||
- **2B predict** — prediction-only, narrow prompt format.
|
||||
- **9B predict** + **9B chat** — prediction plus conversational reasoning.
|
||||
- **27B predict** + **27B chat** — same, larger.
|
||||
|
||||
## Model card
|
||||
|
||||
- https://developers.google.com/health-ai-developer-foundations/txgemma/model-card
|
||||
- DeepMind: https://deepmind.google/models/gemma/txgemma/
|
||||
- Paper: https://deepmind.google/research/publications/153799/
|
||||
|
||||
## Prompting modes
|
||||
|
||||
**Prediction mode** (all sizes): structured TDC-format prompt with instruction + context + question + optional few-shot. Output is a short prediction (sometimes a single token or a float).
|
||||
|
||||
**Conversational mode** (9B, 27B): chat-template interactions, can explain reasoning behind predictions.
|
||||
|
||||
## Minimum invocation — prediction
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model="google/txgemma-27b-predict",
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"Instructions: Predict whether the molecule can penetrate the blood-brain barrier.\n"
|
||||
"Context: Blood-brain barrier penetration is an important property for CNS drugs.\n"
|
||||
"Question: Given the SMILES string CN1C=NC2=C1C(=O)N(C(=O)N2C)C, "
|
||||
"predict BBB penetration. Answer with 'Yes' or 'No'.\n"
|
||||
"Answer:"
|
||||
)
|
||||
|
||||
out = pipe(prompt, max_new_tokens=8)
|
||||
print(out[0]["generated_text"])
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Health AI Developer Foundations — same terms as MedGemma. Non-clinical, research-use.
|
||||
|
||||
## When to choose it over base Gemma 4
|
||||
|
||||
- You're doing **drug-discovery research** and need TDC-format predictions out of the box.
|
||||
- You want **SMILES-aware reasoning** without a custom cheminformatics stack.
|
||||
|
||||
Almost never chosen for general-purpose work. TxGemma's value is the training data, not the base model.
|
||||
|
||||
## Homelab fit
|
||||
|
||||
Zero. Noted for completeness.
|
||||
@@ -0,0 +1,226 @@
|
||||
# Google-official Gemma tooling (as of 2026-04-18)
|
||||
|
||||
Downloaded corpus of canonical Google / Google-DeepMind Gemma tooling. This
|
||||
directory mirrors only **upstream-authored** material — no third-party forks,
|
||||
no community ports, no Ollama-specific content (that lives in
|
||||
`../../CORPUS_ollama_variants.md`).
|
||||
|
||||
Reach for this directory when you need to verify what the canonical code/docs
|
||||
actually say (prompt tokens, API shapes, supported variants) versus what a
|
||||
third-party wrapper claims they say.
|
||||
|
||||
## Top-line findings (flag for cross-check with rest of corpus)
|
||||
|
||||
1. **Canonical JAX/Flax library (`google-deepmind/gemma`) has first-class
|
||||
Gemma 4 support today** — `gm.nn.Gemma4_E4B()`,
|
||||
`gm.ckpts.CheckpointPath.GEMMA4_E4B_IT`, and the unified `ChatSampler` /
|
||||
`ToolSampler` API explicitly lists "2, 3, 3n, 4" as supported. This is the
|
||||
least-friction Python path if you want the actual reference behavior.
|
||||
2. **`google/gemma_pytorch` has NO Gemma 4 support** as of last push
|
||||
(2025-05-30). `scripts/run.py` validates variant in
|
||||
`['2b', '2b-v2', '7b', '9b', '27b', '1b']`; `scripts/run_multimodal.py` in
|
||||
`['4b', '12b', '27b_v3']` (all Gemma 3). If someone tells you to "use
|
||||
the official PyTorch repo" for Gemma 4, they're wrong — it's stale.
|
||||
3. **`google/gemma.cpp` README says Gemma 2-3 + PaliGemma 2 only** (no Gemma 4
|
||||
yet), but the repo is actively pushed and explicitly notes active work
|
||||
happens on the `dev` branch. Worth rechecking `dev` for Gemma 4 support.
|
||||
4. **Gemma 4 uses a NEW prompt-token syntax** distinct from Gemma 1/2/3:
|
||||
- Gemma 1/2/3: `<start_of_turn>` / `<end_of_turn>` (symmetric angle brackets)
|
||||
- Gemma 4: `<|turn>` / `<turn|>` (asymmetric pipe-brackets)
|
||||
- Plus Gemma-4-new: `<|tool>`/`<tool|>`, `<|tool_call>`/`<tool_call|>`,
|
||||
`<|tool_response>`/`<tool_response|>`, `<|think|>`,
|
||||
`<|channel>`/`<channel|>`, `<|image>`/`<image|>`, `<|audio>`/`<audio|>`,
|
||||
string delimiter `<|"|>`.
|
||||
- Roles are named directly: `system`, `user`, `model` (no role brackets).
|
||||
This directly contradicts any chat template built against Gemma 3 tokens.
|
||||
`CORPUS_tool_calling_format.md` already captures the tool tokens correctly
|
||||
but does NOT yet document the turn-token change or the thinking tokens.
|
||||
5. **`gemma.cpp` ships an HTTP API server (`gemma_api_server`) that speaks
|
||||
the Google Gemini API protocol** (`POST /v1beta/models/<model>:generateContent`,
|
||||
SSE streaming, session management). This is a canonical Google-built
|
||||
alternative to Ollama that implements the *real* Gemini REST API locally.
|
||||
See `gemma-cpp/API_SERVER_README.md`.
|
||||
6. **Tool use was NOT a trained capability in Gemma 1/2/3** — the DeepMind
|
||||
`colabs/tool_use.ipynb` explicitly disclaims: *"The Gemma 1, 2 and 3 models
|
||||
were not specifically trained for tool use. This is more a proof-of-concept
|
||||
than an officially supported feature."* Gemma 4 is notably absent from that
|
||||
caveat; the cookbook and blog confirm Gemma 4 has **native function
|
||||
calling** as a first-class trained capability.
|
||||
7. **No Gemma 4 technical-report PDF exists yet.** All conventional URLs
|
||||
(`storage.googleapis.com/deepmind-media/gemma/Gemma4Report.pdf`,
|
||||
`goo.gle/gemma4report`) return 404/redirect-to-google.com, and the
|
||||
DeepMind repo README explicitly says "Gemma 4 (Coming soon)". Current
|
||||
most-authoritative scientific document for the family is the Gemma 3
|
||||
technical report (arXiv:2503.19786), downloaded here.
|
||||
8. **Cookbook ships a Gemma-4-specific agentic reference app**
|
||||
(`apps/Gemma_4_HDP_Agentic_Security/`) demonstrating how to cryptographically
|
||||
gate Gemma 4's native function calls with Ed25519-signed delegation tokens
|
||||
(IETF draft `draft-helixar-hdp-agentic-delegation-00`). A more
|
||||
production-shaped pattern than the toy `tool_use.ipynb`.
|
||||
|
||||
## File index
|
||||
|
||||
### `deepmind-gemma/` — JAX/Flax reference (the primary Python library)
|
||||
Upstream: https://github.com/google-deepmind/gemma (`main`, pushed 2026-04-17).
|
||||
|
||||
| File | What | Why keep |
|
||||
|------|------|----------|
|
||||
| `README.md` | PyPI `gemma` package entry point | Shows canonical `gm.nn.Gemma4_E4B()` API, `ChatSampler` multi-turn/multi-modal example |
|
||||
| `example_multimodal.py` | Image-captioning fine-tune (Kauldron config) | Canonical end-to-end SFT example; docstring shows exact `<start_of_turn>user / <start_of_image> / <end_of_turn>` interleave for Gemma 3 |
|
||||
| `example_lora.py` | LoRA fine-tuning recipe | Reach for this if doing PEFT against a Gemma 4 checkpoint |
|
||||
| `example_dpo.py` | Direct Preference Optimization recipe | Reference for preference-alignment post-training |
|
||||
| `example_classification.py` | Classification fine-tune | Shows Gemma as a feature extractor |
|
||||
| `example_sharding.py` | Multi-device sharding | Reference for running >E4B on multi-GPU/TPU |
|
||||
| `colab_tool_use.ipynb` | Tool-use demo (`ToolSampler`) | Important caveat inside: "not specifically trained for tool use" for Gemma 1/2/3; shows the `gm.tools.Tool` base class API |
|
||||
| `colab_sampling.ipynb` | Basic inference / chat notebook | Starter-grade canonical sampling example |
|
||||
|
||||
Other scripts in the repo (not downloaded, cherry-picked above): `seq2seq.py`, `npo.py`, colabs for `quantization_aware_training`, `sharding`, `tokenizer`, `multimodal`, `finetuning`, `lora_finetuning`, `lora_sampling`. Fetch directly from https://github.com/google-deepmind/gemma/tree/main when needed.
|
||||
|
||||
### `gemma-pytorch/` — PyTorch reference (STALE for Gemma 4)
|
||||
Upstream: https://github.com/google/gemma_pytorch (`main`, pushed 2025-05-30).
|
||||
|
||||
| File | What | Why keep |
|
||||
|------|------|----------|
|
||||
| `README.md` | Entry-point docs | Only documents up through Gemma 3; no Gemma 4 |
|
||||
| `run.py` | Text-only inference entry point | Variant whitelist `['2b','2b-v2','7b','9b','27b','1b']` — Gemma 1/2 only |
|
||||
| `run_multimodal.py` | Multimodal inference entry point | Variant whitelist `['4b','12b','27b_v3']` — Gemma 3 only. Shows exact interleaved `<start_of_turn>user\n`, image, `text, <end_of_turn>\n<start_of_turn>model` pattern |
|
||||
| `run_xla.py` | TPU/XLA inference | Reference for running Gemma 3 on TPU |
|
||||
|
||||
**Do not reach for this repo for Gemma 4 work** until it's updated. Use the
|
||||
DeepMind JAX lib, Hugging Face `transformers`, or gemma.cpp instead.
|
||||
|
||||
### `gemma-cpp/` — C++ reference inference
|
||||
Upstream: https://github.com/google/gemma.cpp (`main`, pushed 2026-04-17; active dev on `dev` branch).
|
||||
|
||||
| File | What | Why keep |
|
||||
|------|------|----------|
|
||||
| `README.md` | Project overview, build instructions | States "Gemma 2-3 + PaliGemma 2" in features; Gemma 4 status unclear from `main` — check `dev` branch |
|
||||
| `API_SERVER_README.md` | HTTP API server that speaks Gemini API protocol | **Most interesting find** — canonical drop-in for apps written against the Gemini API, runs locally. `POST /v1beta/models/<model>:generateContent`, SSE streaming, session KV-cache |
|
||||
| `examples_README.md` | Pointer to `hello_world` / `simplified_gemma` minimal embedding examples | Starting point for embedding gemma.cpp into your own C++ binary |
|
||||
|
||||
### `cookbook/` — Official recipes and end-to-end apps
|
||||
Upstream: https://github.com/google-gemma/cookbook (`main`, pushed 2026-04-17).
|
||||
**Note:** `google-gemini/gemma-cookbook` now 301-redirects here; use the
|
||||
`google-gemma/cookbook` URL going forward.
|
||||
|
||||
| File | What | Why keep |
|
||||
|------|------|----------|
|
||||
| `README.md` | Cookbook index | Authoritative list of Gemma variants incl. Gemma 4 (E2B / E4B / 26B A4B / 31B), the ecosystem (FunctionGemma, MedGemma, PaliGemma 2, RecurrentGemma, ShieldGemma 2, T5Gemma, TranslateGemma, TxGemma, VaultGemma, EmbeddingGemma) |
|
||||
| `tutorials_RAG_EmbeddingGemma.ipynb` | RAG with EmbeddingGemma | Currently the only notebook in `tutorials/` — reflects the "latest tested" tier |
|
||||
| `docs_gemma_chat.ipynb` | Chatbot with Gemma on Keras | Documents the `__START_TURN_USER__ = "<start_of_turn>user\n"` / `__END_TURN__ = "<end_of_turn>\n"` format explicitly; Gemma 2 example, but the class is the canonical illustration of the Gemma 1/2/3 chat template |
|
||||
| `apps_Gemma4_HDP_AgenticSecurity_README.md` | README for the HDP agentic-security reference app | Gemma-4-specific demo; real production pattern for gating native function calls |
|
||||
| `apps_Gemma4_HDP_hdp_middleware.py` | Drop-in middleware (`HDPMiddleware.gate()`) | Wraps any Gemma 4 tool executor with Ed25519-signed HDT verification |
|
||||
| `apps_Gemma4_HDP_AgenticSecurity.ipynb` | Walkthrough notebook | End-to-end: load Gemma 4, issue tokens, gate function calls |
|
||||
|
||||
Other cookbook content worth noting (not downloaded — fetch on demand):
|
||||
- `docs/capabilities/thinking.ipynb` (438 KB) — Gemma 4 thinking-mode notebook
|
||||
- `docs/capabilities/audio.ipynb` — audio-input capability
|
||||
- `docs/functiongemma/{finetuning-with-functiongemma,full-function-calling-sequence-with-functiongemma,function-calling-with-hf}.ipynb` — **FunctionGemma** is a separate fine-tune on the Gemma 3 270M IT checkpoint specifically for function calling; distinct from Gemma 4's native function calling
|
||||
- `docs/core/pytorch_gemma.ipynb`, `keras_inference.ipynb`, `huggingface_*.ipynb` — framework-specific recipes
|
||||
- `docs/integrations/langchain.ipynb` — LangChain integration
|
||||
- `experiments/{MedGemma,TxGemma}/` and `experiments/[T5Gemma]Example.ipynb`, `[VaultGemma]FineTuning_Inference_Huggingface.ipynb`, etc. — domain-specific Gemma variants
|
||||
|
||||
### `docs/` — Canonical ai.google.dev pages (HTML cached)
|
||||
Verified URLs below; HTML snapshots saved for verbatim preservation.
|
||||
|
||||
| File | Source URL |
|
||||
|------|-----------|
|
||||
| `ai-google-dev_core.html` | https://ai.google.dev/gemma/docs/core — Gemma 4 overview |
|
||||
| `ai-google-dev_model_card_4.html` | https://ai.google.dev/gemma/docs/core/model_card_4 — Gemma 4 model card |
|
||||
| `ai-google-dev_prompt_formatting_gemma4.html` | https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4 — **Gemma 4 prompt tokens (new `<\|turn>`/`<turn\|>` syntax)** |
|
||||
| `ai-google-dev_function_calling_gemma4.html` | https://ai.google.dev/gemma/docs/capabilities/text/function-calling-gemma4 — **Gemma 4 native function calling spec** |
|
||||
| `ai-google-dev_formatting.html` | https://ai.google.dev/gemma/docs/formatting — Gemma 1/2/3 prompt format (`<start_of_turn>`/`<end_of_turn>`) |
|
||||
| `blog_announcement.html` | https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/ — Gemma 4 launch blog, 2026-04-02 |
|
||||
|
||||
Other canonical doc URLs (verified to exist, not snapshotted here — visit
|
||||
directly):
|
||||
- https://ai.google.dev/gemma/docs — top-level Gemma hub
|
||||
- https://ai.google.dev/gemma/docs/releases — release history
|
||||
- https://ai.google.dev/gemma/docs/functiongemma — FunctionGemma variant
|
||||
- https://ai.google.dev/gemma/docs/core/deploy_to_cloud_run_from_ai_studio — AI Studio → Cloud Run
|
||||
- https://docs.cloud.google.com/vertex-ai/generative-ai/docs/open-models/use-gemma — Vertex AI
|
||||
- https://aistudio.google.com — AI Studio
|
||||
- https://gemma-llm.readthedocs.io — DeepMind JAX lib docs
|
||||
- https://www.kaggle.com/models/google/gemma-4 — Gemma 4 on Kaggle
|
||||
- https://huggingface.co/collections/google/gemma-4 — Gemma 4 on HF
|
||||
|
||||
### `tech-report/`
|
||||
| File | What | Source |
|
||||
|------|------|--------|
|
||||
| `Gemma3Report.pdf` | **Gemma 3 Technical Report** (arXiv:2503.19786, 2025-03-12) | https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf |
|
||||
|
||||
No Gemma 4 technical report exists yet. Probed paths that return 404:
|
||||
- `Gemma4Report.pdf`, `gemma4-report.pdf`, `Gemma4Report_v1.pdf` under
|
||||
`storage.googleapis.com/deepmind-media/gemma/`
|
||||
- `goo.gle/gemma4report` (not configured — redirects to google.com)
|
||||
|
||||
DeepMind repo README line: **"Gemma 4 (Coming soon)"**. The Gemma 3 report
|
||||
remains the most-authoritative Google-DeepMind scientific document for the
|
||||
family and is the correct citation for architecture fundamentals (Grouped-Query
|
||||
Attention with post-norm/pre-norm RMSNorm, 5:1 local/global attention layer
|
||||
interleave, 1024-token local sliding window, RoPE base 1M on global / 10k on
|
||||
local, SigLIP 400M vision encoder at 896×896 shared across 4B/12B/27B and
|
||||
frozen during training, SentencePiece tokenizer with 262k vocab shared with
|
||||
Gemini 2.0, knowledge distillation during pre-training, QAT checkpoints via
|
||||
5k-step fine-tune for int4/SFP8). Per-variant parameter counts for Gemma 3:
|
||||
1B = 698M non-embedding + 302M embedding, 4B = 3209M + 675M, 12B = 10759M +
|
||||
1012M, 27B = 25600M + 1416M.
|
||||
|
||||
## Canonical Gemma 4 prompt format (verified 2026-04-18)
|
||||
|
||||
**Source:** https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4 and
|
||||
https://ai.google.dev/gemma/docs/capabilities/text/function-calling-gemma4
|
||||
|
||||
Note the `<|turn>` / `<turn|>` are asymmetric — opening has the pipe on the
|
||||
left, closing has the pipe on the right. Same for all paired delimiters.
|
||||
|
||||
```
|
||||
<|turn>system
|
||||
<|think|> (optional — activates thinking mode)
|
||||
<|tool>declaration:FUNCTION_NAME{description:<|"|>...<|"|>,parameters:{properties:{...},required:[...]}}<tool|>
|
||||
You are a helpful assistant.<turn|>
|
||||
<|turn>user
|
||||
What's the weather in Tokyo?<turn|>
|
||||
<|turn>model
|
||||
<|channel>thought
|
||||
...internal reasoning...<channel|>
|
||||
<|tool_call>call:get_current_weather{location:<|"|>Tokyo, JP<|"|>}<tool_call|>
|
||||
<|tool_response>response:get_current_weather{temperature:15,weather:<|"|>sunny<|"|>}<tool_response|>
|
||||
The current weather in Tokyo is 15 degrees and sunny.<turn|>
|
||||
```
|
||||
|
||||
Recommended sampling (per model card, verified):
|
||||
`temperature=1.0, top_p=0.95, top_k=64`. Tokenizer vocab = **262k** (same as
|
||||
Gemini 2.0). **BOS token required** — prepend `[BOS]` / set `add_bos=True`.
|
||||
|
||||
**Gemma 1/2/3 prompt format (different — for reference):**
|
||||
```
|
||||
<start_of_turn>user
|
||||
[message]<end_of_turn>
|
||||
<start_of_turn>model
|
||||
[response]<end_of_turn>
|
||||
```
|
||||
Gemma 1/2/3 have no trained tool-use or thinking tokens. PT models end with
|
||||
`<eos>`; IT models end with `<end_of_turn>`.
|
||||
|
||||
## Gemma 4 variants (canonical spec from model card)
|
||||
|
||||
| Variant | Params | Active | Context | Multimodal |
|
||||
|---------|--------|--------|---------|------------|
|
||||
| Gemma 4 E2B | 2.3B effective (5.1B w/ embeddings), 35 layers | — | 128K | text+image+audio (30s max) |
|
||||
| Gemma 4 E4B | 4.5B effective (8B w/ embeddings), 42 layers | — | 128K | text+image+audio (30s max) |
|
||||
| Gemma 4 26B A4B | 25.2B total (MoE), 30 layers | 3.8B | 256K | text+image |
|
||||
| Gemma 4 31B | 30.7B dense, 60 layers | — | 256K | text+image |
|
||||
|
||||
All variants: Apache 2.0, base + instruction-tuned (`-it`), 140+ languages,
|
||||
native function calling, native structured JSON output. Vision encoder = 150M
|
||||
(E2B/E4B) or 550M (26B/31B). Image resolution token budgets: 70, 140, 280,
|
||||
560, 1120. Released 2026-04-02.
|
||||
|
||||
## Fetched using
|
||||
|
||||
All files fetched via `curl -sL` from `raw.githubusercontent.com` on
|
||||
2026-04-18. Repos enumerated via the GitHub API
|
||||
(`https://api.github.com/repos/<owner>/<repo>/contents/<path>`). Google docs
|
||||
pages fetched via WebFetch tool. No GitHub auth needed for public raw files
|
||||
(unauthenticated rate limit = 60 req/hr, sufficient for this task).
|
||||
@@ -0,0 +1,80 @@
|
||||
|
||||
# Welcome to the Gemma Cookbook
|
||||
This is a collection of guides and examples for [Google Gemma](https://ai.google.dev/gemma/).
|
||||
|
||||
> **Disclaimer:** Gemma is a family of developer-focused models built by Google DeepMind. This cookbook is a collection of guides and examples for Google Gemma. Please keep in mind that Gemma is an open model and can hallucinate as you build on examples in this cookbook.
|
||||
|
||||
## Repository Structure
|
||||
* [**Tutorials**](tutorials/): The latest tested notebooks for Gemma models and variants.
|
||||
* [**Apps**](apps/): Full-stack demos and complex end-to-end use cases.
|
||||
* [**Experiments**](experiments/): Research-focused model notebooks, including [TxGemma](experiments/TxGemma) and [MedGemma](experiments/MedGemma).
|
||||
* [**Responsible**](responsible/): Notebooks for responsible AI development.
|
||||
* [**Docs**](docs/): Core documentation, capabilities, and technical guides.
|
||||
* [**Archive**](.archive/): All older notebooks and historical examples.
|
||||
|
||||
## Get started with the Gemma models
|
||||
Gemma is a family of lightweight, generative artificial intelligence (AI) open models, built from the same research and technology used to create the Gemini models. The Gemma model family includes:
|
||||
* Gemma\
|
||||
The core models of the Gemma family.
|
||||
* [Gemma](https://ai.google.dev/gemma/docs/core/model_card)\
|
||||
For a variety of text generation tasks and can be further tuned for specific use cases
|
||||
* [Gemma 2](https://ai.google.dev/gemma/docs/core/model_card_2)\
|
||||
Higher-performing and more efficient, available in 2B, 9B, 27B parameter sizes
|
||||
* [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3)\
|
||||
Longer context window and handling text and image input, available in 1B, 4B, 12B, and 27B parameter sizes
|
||||
* [Gemma 3n](https://ai.google.dev/gemma/docs/gemma-3n/model_card) \
|
||||
Designed for efficient execution on low-resource devices. Handling text, image, video, and audio input, available in E2B and E4B parameter sizes
|
||||
* [Gemma 4](https://ai.google.dev/gemma/docs/core/model_card_4)\
|
||||
Well-suited for reasoning, agentic workflows, coding, and multimodal understanding, available in E2B, E4B, 26B A4B, and 31B parameter sizes.
|
||||
* Gemma variants
|
||||
* [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)\
|
||||
Fine-tuned for a variety of coding tasks
|
||||
* [DataGemma](https://ai.google.dev/gemma/docs/datagemma)\
|
||||
Fine-tuned for using Data Commons to address AI hallucinations
|
||||
* [FunctionGemma](https://ai.google.dev/gemma/docs/functiongemma)\
|
||||
Fine-tuned on Gemma 3 270M IT checkpoint for function calling
|
||||
* [MedGemma](https://developers.google.com/health-ai-developer-foundations/medgemma)
|
||||
The MedGemma collection contains Google's most capable open models for medical text and image comprehension, built on Gemma 3. Developers can use MedGemma to accelerate building healthcare-based AI applications. MedGemma comes in two variants: a 4B multimodal version and a 27B text-only version.
|
||||
* [PaliGemma](https://ai.google.dev/gemma/docs/paligemma/model-card)\
|
||||
Vision Language Model\
|
||||
For a deeper analysis of images and provide useful insights
|
||||
* [PaliGemma 2](https://ai.google.dev/gemma/docs/paligemma/model-card-2)\
|
||||
VLM which incorporates the capabilities of the Gemma 2 models
|
||||
* [RecurrentGemma](https://ai.google.dev/gemma/docs/recurrentgemma)\
|
||||
Based on [Griffin](https://arxiv.org/abs/2402.19427) architecture\
|
||||
For a variety of text generation tasks
|
||||
* [ShieldGemma](https://ai.google.dev/gemma/docs/shieldgemma/model_card)\
|
||||
Fine-tuned for evaluating the safety of text prompt input and text output responses against a set of defined safety policies
|
||||
* [ShieldGemma 2](https://ai.google.dev/gemma/docs/shieldgemma/model_card_2)\
|
||||
Fine-tuned on Gemma 3 4B IT checkpoint for image safety classification
|
||||
* [T5Gemma](https://deepmind.google/models/gemma/t5gemma)\
|
||||
A collection of encoder-decoder models that provide a strong quality-inference efficiency tradeoff
|
||||
* [TranslateGemma](https://huggingface.co/collections/google/translategemma)\
|
||||
A collection of open model designed to handle translation tasks across 55 languages
|
||||
* [TxGemma](https://deepmind.google/models/gemma/txgemma)\
|
||||
A collection of open models designed to improve the efficiency of therapeutic development
|
||||
* [VaultGemma](https://deepmind.google/models/gemma/vaultgemma)\
|
||||
An open model trained from the ground up using differential privacy to prevent memorization and leaking of training data examples
|
||||
|
||||
You can find the Gemma models on the Hugging Face Hub, Kaggle, Google Cloud Vertex AI Model Garden, and [ai.nvidia.com](https://ai.nvidia.com).
|
||||
|
||||
## Additional Resources
|
||||
* [MedGemma on Google-Health](https://github.com/Google-Health/medgemma/tree/main/notebooks) : Google-Health has additional notebooks for using MedGemma
|
||||
* [Gemma on Google Cloud](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/open-models) : GCP open models has additional notebooks for using Gemma
|
||||
|
||||
## Get help
|
||||
Ask a Gemma cookbook-related question on the [developer forum](https://discuss.ai.google.dev/c/gemma/10), or open an [issue](https://github.com/google-gemini/gemma-cookbook/issues) on GitHub.
|
||||
|
||||
## Wish list
|
||||
If you want to see additional cookbooks implemented for specific features/integrations, please open a new issue with [“Feature Request” template](https://github.com/google-gemini/gemma-cookbook/issues/new?template=feature_request.yml).
|
||||
|
||||
If you want to make contributions to the Gemma Cookbook project, you are welcome to pick any idea in the [“Wish List”](https://github.com/google-gemini/gemma-cookbook/labels/wishlist) and implement it.
|
||||
|
||||
## Contributing
|
||||
Contributions are always welcome. Please read [contributing](https://github.com/google-gemini/gemma-cookbook/blob/main/CONTRIBUTING.md) before implementation.
|
||||
|
||||
Thank you for developing with Gemma! We’re excited to see what you create.
|
||||
|
||||
## Translation of this repository
|
||||
* [Traditional Chinese](https://github.com/doggy8088/gemma-cookbook)
|
||||
* [Simplified Chinese](https://github.com/xiaoxiong1006/gemma-cookbook)
|
||||
@@ -0,0 +1,526 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "colab-badge"
|
||||
},
|
||||
"source": [
|
||||
"<table align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/apps/Gemma_4_HDP_Agentic_Security/Gemma_4_HDP_Agentic_Security.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "byline"
|
||||
},
|
||||
"source": [
|
||||
"# Securing Gemma 4 Agentic Workflows with HDP\n",
|
||||
"\n",
|
||||
"**Author:** Asiri Dalugoda, Helixar Limited ([@asiridalugoda](https://github.com/asiridalugoda)) | [helixar.ai](https://helixar.ai)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gpu-instructions"
|
||||
},
|
||||
"source": [
|
||||
"## Before you begin\n",
|
||||
"\n",
|
||||
"This notebook requires a GPU runtime. To enable GPU in Colab:\n",
|
||||
"1. Go to **Runtime → Change runtime type**\n",
|
||||
"2. Set **Hardware accelerator** to **GPU** (T4 is sufficient for E4B)\n",
|
||||
"3. Click **Save**\n",
|
||||
"\n",
|
||||
"You will also need a **Hugging Face token** to download Gemma 4 (gated model):\n",
|
||||
"1. Go to [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
|
||||
"2. Create a token with **Read** access\n",
|
||||
"3. Accept the Gemma 4 model license at [huggingface.co/google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)\n",
|
||||
"4. Run the cell below to authenticate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hf-login"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "overview"
|
||||
},
|
||||
"source": [
|
||||
"# Securing Gemma 4 Agentic Workflows with HDP\n",
|
||||
"\n",
|
||||
"**Human Delegation Provenance (HDP)** is an open protocol that adds a cryptographic chain-of-custody to AI agent function calls — ensuring every tool invocation can be traced back to an authorized human principal.\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to integrate HDP with Gemma 4's native function-calling capability to:\n",
|
||||
"\n",
|
||||
"- **Verify** that Gemma 4's function calls were authorized by a human principal before execution\n",
|
||||
"- **Classify** actions by irreversibility (read-only → irreversible → physical actuation)\n",
|
||||
"- **Block** unauthorized or out-of-scope tool calls at the middleware layer\n",
|
||||
"- **Audit** every decision with a pre-execution log\n",
|
||||
"\n",
|
||||
"This is particularly relevant for Gemma 4 deployments on edge devices (Jetson Nano, Raspberry Pi) where the model may be directing physical actuators offline with no out-of-band authorization check.\n",
|
||||
"\n",
|
||||
"**References:**\n",
|
||||
"- HDP IETF draft: [draft-helixar-hdp-agentic-delegation-00](https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/)\n",
|
||||
"- HDP-P (physical AI agents): [DOI 10.5281/ZENODO.19332440](https://doi.org/10.5281/ZENODO.19332440)\n",
|
||||
"- Helixar: [helixar.ai](https://helixar.ai)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b3600ee25c8e"
|
||||
},
|
||||
"source": [
|
||||
"## Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "7a80251f52b3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q transformers torch cryptography"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ed80fe18f255"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Download the middleware\n",
|
||||
"!wget -q https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/apps/Gemma_4_HDP_Agentic_Security/hdp_middleware.py\n",
|
||||
"\n",
|
||||
"from hdp_middleware import (\n",
|
||||
" HDPDelegationToken,\n",
|
||||
" HDPMiddleware,\n",
|
||||
" IrreversibilityClass,\n",
|
||||
" DEFAULT_TOOL_CLASS_MAP,\n",
|
||||
")\n",
|
||||
"from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey\n",
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e88bdc7b7265"
|
||||
},
|
||||
"source": [
|
||||
"## 1. Load Gemma 4\n",
|
||||
"\n",
|
||||
"We use the 4B Effective model for this demo. For production agentic deployments, the 26B MoE or 31B Dense models are recommended."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1e4e7779806d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"# For edge/robotics use cases: swap to google/gemma-4-E2B-it\n",
|
||||
"MODEL_ID = \"google/gemma-4-E4B-it\"\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\n",
|
||||
" \"text-generation\",\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "d91e36cfb0b2"
|
||||
},
|
||||
"source": [
|
||||
"## 2. Define Tools\n",
|
||||
"\n",
|
||||
"Gemma 4 uses structured JSON function-calling. We define a tool set spanning different IrreversibilityClasses to demonstrate the middleware's classification behaviour."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1becdb52e7f8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TOOLS = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"get_weather\",\n",
|
||||
" \"description\": \"Get the current weather for a location.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\"type\": \"string\", \"description\": \"City name\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"send_email\",\n",
|
||||
" \"description\": \"Send an email to a recipient.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"to\": {\"type\": \"string\"},\n",
|
||||
" \"subject\": {\"type\": \"string\"},\n",
|
||||
" \"body\": {\"type\": \"string\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"to\", \"subject\", \"body\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"delete_file\",\n",
|
||||
" \"description\": \"Permanently delete a file by path.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"path\": {\"type\": \"string\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"path\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"actuate_robot_arm\",\n",
|
||||
" \"description\": \"Command a robot arm to move to a target position.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"joint_angles\": {\"type\": \"array\", \"items\": {\"type\": \"number\"}},\n",
|
||||
" \"force_limit_n\": {\"type\": \"number\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"joint_angles\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Tools indexed by name for lookup\n",
|
||||
"TOOL_REGISTRY = {t[\"name\"]: t for t in TOOLS}\n",
|
||||
"print(f\"Registered {len(TOOLS)} tools\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "722948b00a92"
|
||||
},
|
||||
"source": [
|
||||
"## 3. Issue an HDP Delegation Token\n",
|
||||
"\n",
|
||||
"The human principal generates an Ed25519 keypair and issues an HDT that specifies:\n",
|
||||
"- Which tools the agent is permitted to call\n",
|
||||
"- The maximum IrreversibilityClass the agent can act on\n",
|
||||
"- The token's lifetime"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "b0622c68dfa5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Human principal generates their signing keypair\n",
|
||||
"# In production: loaded from secure key storage (HSM, OS keychain, etc.)\n",
|
||||
"principal_private_key = Ed25519PrivateKey.generate()\n",
|
||||
"principal_public_key = principal_private_key.public_key()\n",
|
||||
"\n",
|
||||
"# Issue an HDT authorizing the Gemma 4 agent to call weather queries\n",
|
||||
"# and send emails (Class 0 and Class 2), but NOT delete files or actuate hardware\n",
|
||||
"token = HDPDelegationToken.issue(\n",
|
||||
" principal_id=\"alice@example.com\",\n",
|
||||
" agent_id=\"gemma4-agent-01\",\n",
|
||||
" scope=[\"get_weather\", \"send_email\"],\n",
|
||||
" max_class=IrreversibilityClass.CLASS_2,\n",
|
||||
" ttl_seconds=3600,\n",
|
||||
" private_key=principal_private_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(json.dumps(token.to_dict(), indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e206f950f4bc"
|
||||
},
|
||||
"source": [
|
||||
"## 4. Initialise the HDP Middleware\n",
|
||||
"\n",
|
||||
"The middleware takes the principal's **public key** only — it verifies but cannot issue tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e24676f528bf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"audit_log = []\n",
|
||||
"\n",
|
||||
"# Confirmation callback for Class 2 (irreversible) actions.\n",
|
||||
"# In production: this would invoke a push notification, SMS OTP,\n",
|
||||
"# or hardware confirmation device to the human principal.\n",
|
||||
"def require_human_confirmation(tool_name: str, parameters: dict) -> bool:\n",
|
||||
" print(f\"\\n⚠️ Class 2 action requested: {tool_name}\")\n",
|
||||
" print(f\" Parameters: {json.dumps(parameters, indent=4)}\")\n",
|
||||
" response = input(\" Confirm? [y/N]: \").strip().lower()\n",
|
||||
" return response == \"y\"\n",
|
||||
"\n",
|
||||
"middleware = HDPMiddleware(\n",
|
||||
" public_key=principal_public_key,\n",
|
||||
" tool_class_map=DEFAULT_TOOL_CLASS_MAP,\n",
|
||||
" confirmation_callback=require_human_confirmation,\n",
|
||||
" audit_log=audit_log,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"HDP middleware initialised.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "72d56542eba0"
|
||||
},
|
||||
"source": [
|
||||
"## 5. Gemma 4 Function Call → HDP Gate → Tool Execution\n",
|
||||
"\n",
|
||||
"This is the core integration pattern. Every function call Gemma 4 generates is passed through `middleware.gate()` before being forwarded to tool execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "da20bc191e71"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Simulated Gemma 4 function call outputs\n",
|
||||
"# In production these come from parsing Gemma 4's structured JSON output\n",
|
||||
"gemma_function_calls = [\n",
|
||||
" # ✅ Should ALLOW — Class 0, in scope\n",
|
||||
" {\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}},\n",
|
||||
"\n",
|
||||
" # ⚠️ Should CONFIRM then ALLOW — Class 2, in scope\n",
|
||||
" {\"name\": \"send_email\", \"parameters\": {\n",
|
||||
" \"to\": \"bob@example.com\",\n",
|
||||
" \"subject\": \"Weekly report\",\n",
|
||||
" \"body\": \"Please find attached.\"\n",
|
||||
" }},\n",
|
||||
"\n",
|
||||
" # ❌ Should BLOCK — Class 2, NOT in HDT scope\n",
|
||||
" {\"name\": \"delete_file\", \"parameters\": {\"path\": \"/data/important.csv\"}},\n",
|
||||
"\n",
|
||||
" # ❌ Should BLOCK — Class 3, physical actuation\n",
|
||||
" {\"name\": \"actuate_robot_arm\", \"parameters\": {\n",
|
||||
" \"joint_angles\": [0.0, -1.57, 0.0, -1.57, 0.0, 0.0],\n",
|
||||
" \"force_limit_n\": 50.0\n",
|
||||
" }},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(\"HDP VERIFICATION RESULTS\")\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"\n",
|
||||
"for call in gemma_function_calls:\n",
|
||||
" result = middleware.gate(call, token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "be0d0dd05bce"
|
||||
},
|
||||
"source": [
|
||||
"## 6. Audit Log\n",
|
||||
"\n",
|
||||
"Every decision is logged pre-execution. This is the HDP audit trail — a cryptographically linked record of what was authorized, by whom, and when."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e6dbab6d88d1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"\\nAUDIT LOG\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"for i, entry in enumerate(audit_log):\n",
|
||||
" status = \"✅ ALLOWED\" if entry.allowed else \"❌ BLOCKED\"\n",
|
||||
" print(f\"{i+1}. {status} | {entry.tool_name} | {entry.action_class.name} | {entry.reason}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bcadcb7040db"
|
||||
},
|
||||
"source": [
|
||||
"## 7. Token Expiry and Scope Violation Demo\n",
|
||||
"\n",
|
||||
"Demonstrate that expired tokens and out-of-scope calls are blocked regardless of the action class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "deb2e3b6b20e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# Issue a token that's already expired\n",
|
||||
"expired_token = HDPDelegationToken.issue(\n",
|
||||
" principal_id=\"alice@example.com\",\n",
|
||||
" agent_id=\"gemma4-agent-01\",\n",
|
||||
" scope=[\"get_weather\"],\n",
|
||||
" max_class=IrreversibilityClass.CLASS_0,\n",
|
||||
" ttl_seconds=-1, # expired immediately\n",
|
||||
" private_key=principal_private_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Testing expired token:\")\n",
|
||||
"middleware.gate({\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}}, expired_token)\n",
|
||||
"\n",
|
||||
"print(\"\\nTesting call outside HDT scope:\")\n",
|
||||
"middleware.gate({\"name\": \"delete_file\", \"parameters\": {\"path\": \"/etc/passwd\"}}, token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b8f4acddb6fa"
|
||||
},
|
||||
"source": [
|
||||
"## 8. Edge / Robotics Deployment (HDP-P)\n",
|
||||
"\n",
|
||||
"For Gemma 4 E2B/E4B running on Jetson Nano or Raspberry Pi and directing physical actuators, use the HDP-P extension. The key additions are:\n",
|
||||
"\n",
|
||||
"- **Embodiment context** — bind the token to a specific hardware ID\n",
|
||||
"- **Policy attestation** — hash the deployed model weights into the token\n",
|
||||
"- **Fleet delegation constraints** — prevent lateral movement across robot fleet\n",
|
||||
"- **Pre-execution logging** — write audit records *before* actuator commands are issued\n",
|
||||
"\n",
|
||||
"See the [HDP-P specification](https://doi.org/10.5281/ZENODO.19332440) for the full EDT extension structure."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fcf7b451d175"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Minimal HDP-P Embodied Delegation Token (EDT) extension example\n",
|
||||
"# This shows how to attach physical constraints to an HDT\n",
|
||||
"\n",
|
||||
"hdp_p_extension = {\n",
|
||||
" \"hdp-p\": {\n",
|
||||
" \"version\": \"0.1\",\n",
|
||||
" \"embodiment\": {\n",
|
||||
" \"type\": \"mobile\",\n",
|
||||
" \"platform\": \"raspberry-pi-5\",\n",
|
||||
" \"hardware_id\": \"rpi-serial-XXXX\", # TPM-attested in production\n",
|
||||
" \"workspace\": \"lab-zone-a\"\n",
|
||||
" },\n",
|
||||
" \"action_scope\": {\n",
|
||||
" \"permitted_actions\": [\"move_base\", \"read_sensor\"],\n",
|
||||
" \"excluded_zones\": [\"human-workspace\"],\n",
|
||||
" \"force_limit_n\": 10.0,\n",
|
||||
" \"max_velocity_ms\": 0.5\n",
|
||||
" },\n",
|
||||
" \"irreversibility\": {\n",
|
||||
" \"max_class\": 1, # Class 1 max for this token\n",
|
||||
" \"class2_requires_confirmation\": True,\n",
|
||||
" \"class3_prohibited\": True\n",
|
||||
" },\n",
|
||||
" \"policy_attestation\": {\n",
|
||||
" \"policy_hash\": \"sha256:abc123...\", # SHA-256 of deployed model weights\n",
|
||||
" \"training_run_id\": \"gemma4-e2b-it\",\n",
|
||||
" \"sim_validated\": True\n",
|
||||
" },\n",
|
||||
" \"delegation_scope\": {\n",
|
||||
" \"fleet_delegation_permitted\": False, # No lateral movement\n",
|
||||
" \"max_delegation_depth\": 0\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(\"HDP-P EDT extension structure:\")\n",
|
||||
"print(json.dumps(hdp_p_extension, indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b0af7c701dfc"
|
||||
},
|
||||
"source": [
|
||||
"## Summary\n",
|
||||
"\n",
|
||||
"| Layer | What it solves | Tool |\n",
|
||||
"|---|---|---|\n",
|
||||
"| Gemma 4 function calling | Model generates structured tool calls | `pipeline(\"text-generation\")` |\n",
|
||||
"| HDP middleware | Was this call authorized by a human? | `HDPMiddleware.gate()` |\n",
|
||||
"| HDP-P EDT extension | Is this physical action within delegated bounds? | `hdp_p_extension` |\n",
|
||||
"| Audit log | Pre-execution record of every decision | `audit_log` |\n",
|
||||
"\n",
|
||||
"The full HDP specification (IETF draft), HDP-P companion paper, TypeScript SDK, and Python bindings are available at:\n",
|
||||
"\n",
|
||||
"- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/\n",
|
||||
"- **HDP-P paper:** https://doi.org/10.5281/ZENODO.19332440\n",
|
||||
"- **GitHub:** https://github.com/Helixar-AI\n",
|
||||
"- **Site:** https://helixar.ai"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "Gemma_4_HDP_Agentic_Security.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
# Gemma 4 + HDP: Securing Agentic Function Calls
|
||||
|
||||
This example demonstrates how to integrate the **Human Delegation Provenance (HDP)** protocol with **Gemma 4's native function-calling** to cryptographically verify that every tool invocation was authorized by a human principal before execution.
|
||||
|
||||
## The problem
|
||||
|
||||
Gemma 4 is purpose-built for agentic workflows. Its native function-calling lets it autonomously call tools and APIs across multi-step plans — on anything from a cloud workstation to a Raspberry Pi running a robot offline.
|
||||
|
||||
This creates a gap: when Gemma 4 generates a function call, there is no verifiable record that a human principal authorized that specific action. An injected prompt, a compromised system prompt, or a lateral pivot from another agent can trigger function calls that are indistinguishable from legitimate requests at the tool interface.
|
||||
|
||||
HDP closes this gap.
|
||||
|
||||
## What HDP does
|
||||
|
||||
HDP (IETF draft: `draft-helixar-hdp-agentic-delegation-00`) provides:
|
||||
|
||||
- **Ed25519-signed Delegation Tokens (HDTs)** issued by a human principal
|
||||
- **Scope constraints** — which tools the agent is permitted to call
|
||||
- **Irreversibility classification** (Class 0–3) — from read-only to physical actuation
|
||||
- **Pre-execution verification** — the middleware gate runs *before* any tool executes
|
||||
- **Audit log** — a tamper-evident record of every authorization decision
|
||||
|
||||
For Gemma 4 on **edge devices directing physical actuators** (Jetson Nano, Raspberry Pi + robot arm), the HDP-P companion specification adds embodiment constraints, policy attestation, and fleet delegation controls.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| `Gemma_4_HDP_Agentic_Security.ipynb` | Full walkthrough notebook — load Gemma 4, issue tokens, gate function calls |
|
||||
| `hdp_middleware.py` | Drop-in middleware — `HDPMiddleware.gate()` wraps any Gemma 4 tool executor |
|
||||
|
||||
## Quick start
|
||||
|
||||
```python
|
||||
from hdp_middleware import HDPDelegationToken, HDPMiddleware, IrreversibilityClass
|
||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
|
||||
|
||||
# Human principal issues a delegation token
|
||||
private_key = Ed25519PrivateKey.generate()
|
||||
token = HDPDelegationToken.issue(
|
||||
principal_id="alice@example.com",
|
||||
agent_id="gemma4-agent-01",
|
||||
scope=["get_weather", "send_email"],
|
||||
max_class=IrreversibilityClass.CLASS_2,
|
||||
ttl_seconds=3600,
|
||||
private_key=private_key,
|
||||
)
|
||||
|
||||
# Middleware verifies every Gemma 4 function call before execution
|
||||
middleware = HDPMiddleware(public_key=private_key.public_key())
|
||||
|
||||
result = middleware.gate(
|
||||
function_call={"name": "send_email", "parameters": {"to": "bob@example.com", ...}},
|
||||
token=token,
|
||||
)
|
||||
|
||||
if result.allowed:
|
||||
execute_tool(function_call)
|
||||
```
|
||||
|
||||
## Irreversibility classes
|
||||
|
||||
| Class | Definition | Authorization |
|
||||
|---|---|---|
|
||||
| 0 | Fully reversible — reads, queries | HDT sufficient |
|
||||
| 1 | Reversible with effort — writes, moves | HDT sufficient |
|
||||
| 2 | Irreversible — send, delete, publish | HDT + principal confirmation |
|
||||
| 3 | Irreversible + potentially harmful — physical actuation | Dual-principal required (HDP-P) |
|
||||
|
||||
## References
|
||||
|
||||
- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/
|
||||
- **Zenodo DOI:** https://doi.org/10.5281/zenodo.19332023
|
||||
- **HDP-P (physical AI):** https://doi.org/10.5281/ZENODO.19332440
|
||||
- **Helixar:** https://helixar.ai
|
||||
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
HDP (Human Delegation Provenance) middleware for Gemma 4 function calling.
|
||||
|
||||
Intercepts Gemma 4 function call outputs and verifies that a valid HDP
|
||||
Delegation Token (HDT) authorizes the requested action before forwarding
|
||||
to the tool execution layer.
|
||||
|
||||
Reference: draft-helixar-hdp-agentic-delegation-00
|
||||
https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/
|
||||
DOI: 10.5281/zenodo.19332023
|
||||
|
||||
For physical AI agents (robots, edge devices), see HDP-P:
|
||||
DOI: 10.5281/ZENODO.19332440
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Optional, Callable, Any
|
||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
|
||||
Ed25519PrivateKey,
|
||||
Ed25519PublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding,
|
||||
PublicFormat,
|
||||
PrivateFormat,
|
||||
NoEncryption,
|
||||
)
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Irreversibility Classes (HDP-P §4.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class IrreversibilityClass(IntEnum):
|
||||
"""
|
||||
Classification of physical action reversibility (HDP-P §4.2).
|
||||
|
||||
For digital-only Gemma 4 deployments, all tool calls are Class 0 or 1.
|
||||
For edge/robotics deployments (Jetson Nano, Raspberry Pi + actuators),
|
||||
Class 2 and 3 require explicit pre-execution confirmation.
|
||||
"""
|
||||
CLASS_0 = 0 # Fully reversible — read-only, query, observe
|
||||
CLASS_1 = 1 # Reversible with effort — write, create, move
|
||||
CLASS_2 = 2 # Irreversible under normal conditions — delete, send, publish
|
||||
CLASS_3 = 3 # Irreversible and potentially harmful — physical actuation
|
||||
|
||||
|
||||
# Default tool → irreversibility class mapping.
|
||||
# Deployments should override this for their specific tool set.
|
||||
DEFAULT_TOOL_CLASS_MAP: dict[str, IrreversibilityClass] = {
|
||||
# Class 0 — safe reads
|
||||
"get_weather": IrreversibilityClass.CLASS_0,
|
||||
"search_web": IrreversibilityClass.CLASS_0,
|
||||
"read_file": IrreversibilityClass.CLASS_0,
|
||||
"query_database": IrreversibilityClass.CLASS_0,
|
||||
# Class 1 — reversible writes
|
||||
"write_file": IrreversibilityClass.CLASS_1,
|
||||
"create_record": IrreversibilityClass.CLASS_1,
|
||||
"move_object": IrreversibilityClass.CLASS_1,
|
||||
# Class 2 — irreversible digital actions
|
||||
"send_email": IrreversibilityClass.CLASS_2,
|
||||
"delete_file": IrreversibilityClass.CLASS_2,
|
||||
"publish_post": IrreversibilityClass.CLASS_2,
|
||||
"execute_transaction": IrreversibilityClass.CLASS_2,
|
||||
# Class 3 — physical actuation (HDP-P scope)
|
||||
"actuate_robot_arm": IrreversibilityClass.CLASS_3,
|
||||
"command_vehicle": IrreversibilityClass.CLASS_3,
|
||||
"dispense_fluid": IrreversibilityClass.CLASS_3,
|
||||
"apply_force": IrreversibilityClass.CLASS_3,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HDP Delegation Token (HDT)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class HDPDelegationToken:
|
||||
"""
|
||||
Simplified HDT structure derived from draft-helixar-hdp-agentic-delegation-00.
|
||||
|
||||
In production, HDTs are JOSE/JWT tokens signed with Ed25519.
|
||||
This implementation provides the core claims structure and verification logic.
|
||||
|
||||
Claims:
|
||||
iss — issuer (human principal identifier)
|
||||
sub — subject (agent being delegated to)
|
||||
iat — issued at (unix timestamp)
|
||||
exp — expiry (unix timestamp)
|
||||
scope — list of permitted tool names or wildcard patterns
|
||||
max_irreversibility_class — ceiling on action class (0–3)
|
||||
delegation_depth — remaining delegation hops permitted
|
||||
nonce — replay-attack prevention
|
||||
"""
|
||||
iss: str
|
||||
sub: str
|
||||
iat: int
|
||||
exp: int
|
||||
scope: list[str]
|
||||
max_irreversibility_class: IrreversibilityClass
|
||||
delegation_depth: int = 1
|
||||
nonce: str = ""
|
||||
_signature: bytes = field(default=b"", repr=False)
|
||||
_public_key: Optional[Ed25519PublicKey] = field(default=None, repr=False)
|
||||
|
||||
@classmethod
|
||||
def issue(
|
||||
cls,
|
||||
principal_id: str,
|
||||
agent_id: str,
|
||||
scope: list[str],
|
||||
max_class: IrreversibilityClass,
|
||||
ttl_seconds: int = 3600,
|
||||
delegation_depth: int = 1,
|
||||
private_key: Optional[Ed25519PrivateKey] = None,
|
||||
) -> "HDPDelegationToken":
|
||||
"""
|
||||
Issue a new HDT signed by the human principal's Ed25519 private key.
|
||||
|
||||
Args:
|
||||
principal_id: Human principal identifier (e.g. "alice@example.com")
|
||||
agent_id: Agent being delegated to (e.g. "gemma4-agent-01")
|
||||
scope: List of permitted tool names. Use ["*"] for unrestricted.
|
||||
max_class: Maximum IrreversibilityClass this token permits.
|
||||
ttl_seconds: Token lifetime in seconds.
|
||||
delegation_depth: How many times this token can be re-delegated.
|
||||
private_key: Ed25519 private key for signing. Generated if None.
|
||||
"""
|
||||
now = int(time.time())
|
||||
nonce = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(f"{principal_id}{now}".encode()).digest()[:16]
|
||||
).decode()
|
||||
|
||||
token = cls(
|
||||
iss=principal_id,
|
||||
sub=agent_id,
|
||||
iat=now,
|
||||
exp=now + ttl_seconds,
|
||||
scope=scope,
|
||||
max_irreversibility_class=max_class,
|
||||
delegation_depth=delegation_depth,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
if private_key is None:
|
||||
private_key = Ed25519PrivateKey.generate()
|
||||
|
||||
token._public_key = private_key.public_key()
|
||||
token._signature = private_key.sign(token._canonical_bytes())
|
||||
return token
|
||||
|
||||
def _canonical_bytes(self) -> bytes:
|
||||
"""Deterministic serialisation for signing/verification."""
|
||||
payload = {
|
||||
"iss": self.iss,
|
||||
"sub": self.sub,
|
||||
"iat": self.iat,
|
||||
"exp": self.exp,
|
||||
"scope": sorted(self.scope),
|
||||
"max_irreversibility_class": int(self.max_irreversibility_class),
|
||||
"delegation_depth": self.delegation_depth,
|
||||
"nonce": self.nonce,
|
||||
}
|
||||
return json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
|
||||
|
||||
def verify(self, public_key: Ed25519PublicKey) -> bool:
|
||||
"""Verify the token's Ed25519 signature."""
|
||||
try:
|
||||
public_key.verify(self._signature, self._canonical_bytes())
|
||||
return True
|
||||
except InvalidSignature:
|
||||
return False
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return int(time.time()) > self.exp
|
||||
|
||||
def permits_tool(self, tool_name: str) -> bool:
|
||||
"""Check whether this token's scope covers the requested tool."""
|
||||
if "*" in self.scope:
|
||||
return True
|
||||
return tool_name in self.scope
|
||||
|
||||
def permits_class(self, action_class: IrreversibilityClass) -> bool:
|
||||
return action_class <= self.max_irreversibility_class
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"iss": self.iss,
|
||||
"sub": self.sub,
|
||||
"iat": self.iat,
|
||||
"exp": self.exp,
|
||||
"scope": self.scope,
|
||||
"max_irreversibility_class": int(self.max_irreversibility_class),
|
||||
"delegation_depth": self.delegation_depth,
|
||||
"nonce": self.nonce,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Verification result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
allowed: bool
|
||||
reason: str
|
||||
tool_name: str
|
||||
action_class: IrreversibilityClass
|
||||
token_iss: Optional[str] = None
|
||||
requires_confirmation: bool = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
status = "ALLOWED" if self.allowed else "BLOCKED"
|
||||
conf = " [CONFIRMATION REQUIRED]" if self.requires_confirmation else ""
|
||||
return (
|
||||
f"[HDP] {status}{conf} — tool={self.tool_name} "
|
||||
f"class={self.action_class.name} reason={self.reason}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HDP Middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HDPMiddleware:
|
||||
"""
|
||||
HDP verification gate for Gemma 4 function calls.
|
||||
|
||||
Sits between Gemma 4's function-call output and the tool execution layer.
|
||||
For each function call Gemma 4 generates, this middleware:
|
||||
|
||||
1. Parses the tool name from the function call.
|
||||
2. Looks up its IrreversibilityClass.
|
||||
3. Verifies the attached HDT (signature, expiry, scope, class ceiling).
|
||||
4. For Class 2 actions, invokes the confirmation callback.
|
||||
5. Blocks Class 3 actions unless explicitly pre-authorized with
|
||||
dual verification (HDP-P §5.4).
|
||||
6. Logs all decisions before forwarding or blocking.
|
||||
|
||||
Usage:
|
||||
middleware = HDPMiddleware(
|
||||
public_key=principal_public_key,
|
||||
tool_class_map=DEFAULT_TOOL_CLASS_MAP,
|
||||
confirmation_callback=my_confirmation_fn,
|
||||
)
|
||||
|
||||
# Wrap your tool executor:
|
||||
result = middleware.gate(
|
||||
function_call=gemma_output, # {"name": "...", "parameters": {...}}
|
||||
token=hdp_token,
|
||||
)
|
||||
|
||||
if result.allowed:
|
||||
output = execute_tool(function_call)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
public_key: Ed25519PublicKey,
|
||||
tool_class_map: dict[str, IrreversibilityClass] = None,
|
||||
confirmation_callback: Optional[Callable[[str, dict], bool]] = None,
|
||||
default_class: IrreversibilityClass = IrreversibilityClass.CLASS_1,
|
||||
audit_log: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
public_key: Principal's Ed25519 public key for HDT verification.
|
||||
tool_class_map: Mapping of tool names to IrreversibilityClass.
|
||||
Defaults to DEFAULT_TOOL_CLASS_MAP.
|
||||
confirmation_callback: Called for Class 2 actions. Receives
|
||||
(tool_name, parameters) and returns bool.
|
||||
If None, Class 2 actions are blocked.
|
||||
default_class: Class assigned to unknown tools. Defaults to CLASS_1.
|
||||
audit_log: Optional list to append VerificationResult records to.
|
||||
"""
|
||||
self.public_key = public_key
|
||||
self.tool_class_map = tool_class_map or DEFAULT_TOOL_CLASS_MAP
|
||||
self.confirmation_callback = confirmation_callback
|
||||
self.default_class = default_class
|
||||
self.audit_log = audit_log if audit_log is not None else []
|
||||
|
||||
def classify(self, tool_name: str) -> IrreversibilityClass:
|
||||
"""Return the IrreversibilityClass for a tool name."""
|
||||
return self.tool_class_map.get(tool_name, self.default_class)
|
||||
|
||||
def gate(
|
||||
self,
|
||||
function_call: dict,
|
||||
token: HDPDelegationToken,
|
||||
) -> VerificationResult:
|
||||
"""
|
||||
Main verification gate. Call this for every Gemma 4 function call.
|
||||
|
||||
Args:
|
||||
function_call: Gemma 4 function call dict:
|
||||
{"name": "tool_name", "parameters": {...}}
|
||||
token: HDPDelegationToken issued by the human principal.
|
||||
|
||||
Returns:
|
||||
VerificationResult — check .allowed before executing the tool.
|
||||
"""
|
||||
tool_name = function_call.get("name", "")
|
||||
parameters = function_call.get("parameters", {})
|
||||
action_class = self.classify(tool_name)
|
||||
|
||||
def _block(reason: str) -> VerificationResult:
|
||||
result = VerificationResult(
|
||||
allowed=False,
|
||||
reason=reason,
|
||||
tool_name=tool_name,
|
||||
action_class=action_class,
|
||||
token_iss=token.iss if token else None,
|
||||
)
|
||||
self.audit_log.append(result)
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def _allow(reason: str, requires_confirmation: bool = False) -> VerificationResult:
|
||||
result = VerificationResult(
|
||||
allowed=True,
|
||||
reason=reason,
|
||||
tool_name=tool_name,
|
||||
action_class=action_class,
|
||||
token_iss=token.iss,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
self.audit_log.append(result)
|
||||
print(result)
|
||||
return result
|
||||
|
||||
# ── 1. Token presence ───────────────────────────────────────────────
|
||||
if token is None:
|
||||
return _block("no HDT present")
|
||||
|
||||
# ── 2. Expiry ───────────────────────────────────────────────────────
|
||||
if token.is_expired():
|
||||
return _block("HDT expired")
|
||||
|
||||
# ── 3. Signature ────────────────────────────────────────────────────
|
||||
if not token.verify(self.public_key):
|
||||
return _block("HDT signature invalid")
|
||||
|
||||
# ── 4. Scope ────────────────────────────────────────────────────────
|
||||
if not token.permits_tool(tool_name):
|
||||
return _block(f"tool '{tool_name}' not in HDT scope")
|
||||
|
||||
# ── 5. Irreversibility class ceiling ────────────────────────────────
|
||||
if not token.permits_class(action_class):
|
||||
return _block(
|
||||
f"action class {action_class.name} exceeds HDT ceiling "
|
||||
f"{token.max_irreversibility_class.name}"
|
||||
)
|
||||
|
||||
# ── 6. Class 3 — always blocked without explicit dual verification ──
|
||||
if action_class == IrreversibilityClass.CLASS_3:
|
||||
# In production: implement dual-principal confirmation (HDP-P §5.4)
|
||||
return _block(
|
||||
"Class 3 physical action requires dual-principal confirmation "
|
||||
"(HDP-P §5.4) — not implemented in this middleware instance"
|
||||
)
|
||||
|
||||
# ── 7. Class 2 — confirmation callback required ─────────────────────
|
||||
if action_class == IrreversibilityClass.CLASS_2:
|
||||
if self.confirmation_callback is None:
|
||||
return _block(
|
||||
"Class 2 action requires confirmation callback — "
|
||||
"none configured"
|
||||
)
|
||||
confirmed = self.confirmation_callback(tool_name, parameters)
|
||||
if not confirmed:
|
||||
return _block("Class 2 action — confirmation denied by principal")
|
||||
return _allow("Class 2 confirmed by principal", requires_confirmation=True)
|
||||
|
||||
# ── 8. Class 0 / 1 — allow ─────────────────────────────────────────
|
||||
return _allow(f"HDT valid, scope and class verified")
|
||||
|
||||
def gate_batch(
|
||||
self,
|
||||
function_calls: list[dict],
|
||||
token: HDPDelegationToken,
|
||||
) -> list[VerificationResult]:
|
||||
"""Verify a list of function calls. Returns one result per call."""
|
||||
return [self.gate(fc, token) for fc in function_calls]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,925 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-u7xRR3DeFXz"
|
||||
},
|
||||
"source": [
|
||||
"##### Copyright 2026 Google LLC."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "oed1Dh9SeIlD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "A0UbyyBOeKmV"
|
||||
},
|
||||
"source": [
|
||||
"# RAG with EmbeddingGemma\n",
|
||||
"\n",
|
||||
"<table align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/tutorials/RAG_with_EmbeddingGemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ND35JUp9ecq2"
|
||||
},
|
||||
"source": [
|
||||
"EmbeddingGemma is a lightweight, open embedding model designed for fast, high-quality retrieval on everyday devices like mobile phones. At only 308 million parameters, it's efficient enough to run advanced AI techniques, such as Retrieval Augmented Generation (RAG), directly on your local machine with no internet connection required.\n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"Before starting this tutorial, complete the following steps:\n",
|
||||
"\n",
|
||||
"* Get access to EmbeddingGemma by logging into [Hugging Face](https://huggingface.co/google/embeddinggemma-300M) and selecting **Acknowledge license** for a Gemma model.\n",
|
||||
"* Select a Colab runtime with sufficient resources to run\n",
|
||||
" the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
|
||||
"* Generate a Hugging Face [Access Token](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-token) and use it to login from Colab.\n",
|
||||
"\n",
|
||||
"This notebook will run on an NVIDIA T4 GPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SZ8cw1nPf-NV"
|
||||
},
|
||||
"source": [
|
||||
"### Install Python packages\n",
|
||||
"\n",
|
||||
"Install the libraries required for running the EmbeddingGemma model and generating embeddings. Sentence Transformers is a Python framework for text and image embeddings. For more information, see the [Sentence Transformers](https://www.sbert.net/) documentation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "daXx6O20Q7M0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q -U sentence-transformers transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kYiTsNFSjGJH"
|
||||
},
|
||||
"source": [
|
||||
"After you have accepted the license, you need a valid Hugging Face Token to access the model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "eLagJ9aff9Ks"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Login into Hugging Face Hub\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "IiDcW_rmHBfx"
|
||||
},
|
||||
"source": [
|
||||
"### Load language model\n",
|
||||
"\n",
|
||||
"You will use Gemma 4 E2B to generate responses."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"id": "HX2JFDQI-vg8"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c0b54b8b91da46fdb7ba8fd3aecb5002",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4291694230e74608a2808adde451bd0f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "cb31547f287441aba370d8e7a5fc351e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/1951 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0900cc228bed472094eb986719edfde4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"generation_config.json: 0%| | 0.00/208 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3d195cea1ce044f4827cf06412aed5ec",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3bdb49b389aa4abfbb382fccaceb32be",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "93e44e5dd0fe40d49e0cda367d98aeca",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"chat_template.jinja: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load Gemma\n",
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"MODEL_ID = \"google/gemma-4-E2B-it\"\n",
|
||||
"\n",
|
||||
"pipeline = pipeline(\n",
|
||||
" task=\"text-generation\",\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" dtype=\"auto\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eAg-c23Wh0th"
|
||||
},
|
||||
"source": [
|
||||
"### Load embedding model\n",
|
||||
"\n",
|
||||
"Use the `sentence-transformers` libraries to create an instance of a model class with EmbeddingGemma."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"id": "6Jj1WiTSRRk-"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2c5dc65f501e402fb5ec67d094d925e7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"modules.json: 0%| | 0.00/573 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "10b836de41a0410d8963be637ffa6b9d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config_sentence_transformers.json: 0%| | 0.00/997 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "68e29095344e4d24ac3898638f5a2b0e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"README.md: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "376438be53e14e4b808ce63de0d32cb2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"sentence_bert_config.json: 0%| | 0.00/58.0 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7f2a5a56690e4ed5950ad0c278cc20c7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "264f0c21602640bd9ddfa9d405b5613f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model.safetensors: 0%| | 0.00/1.21G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "70eb603cffa948cc895046a8238abbae",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/314 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "aa608efe38f448898f8a01940a3684df",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b12b1756d9ac4145ae70595454e0e036",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer.json: 0%| | 0.00/33.4M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e6e735942c07444ebfcf2702673762b6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"added_tokens.json: 0%| | 0.00/35.0 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ff92bb744fd54211b20f04aedebaa26d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"special_tokens_map.json: 0%| | 0.00/662 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "114a2560d2124889932f1a6436c4d6ef",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0%| | 0.00/312 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "292f471e215d4ac8a490508ce6963b01",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0%| | 0.00/134 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "dce2f7bc57134d0180f3accdec8d5556",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"2_Dense/model.safetensors: 0%| | 0.00/9.44M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c245d417dc9f4d71850853a107379b16",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0%| | 0.00/134 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "18def66743ae4738b940a4b20c434545",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"3_Dense/model.safetensors: 0%| | 0.00/9.44M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Device: cuda:0\n",
|
||||
"SentenceTransformer(\n",
|
||||
" (0): Transformer({'transformer_task': 'feature-extraction', 'modality_config': {'text': {'method': 'forward', 'method_output_name': 'last_hidden_state'}}, 'module_output_name': 'token_embeddings', 'architecture': 'Gemma3TextModel'})\n",
|
||||
" (1): Pooling({'embedding_dimension': 768, 'pooling_mode': 'mean', 'include_prompt': True})\n",
|
||||
" (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity', 'module_input_name': 'sentence_embedding', 'module_output_name': 'sentence_embedding'})\n",
|
||||
" (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity', 'module_input_name': 'sentence_embedding', 'module_output_name': 'sentence_embedding'})\n",
|
||||
" (4): Normalize({})\n",
|
||||
")\n",
|
||||
"Total number of parameters in the model: 307581696\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from sentence_transformers import SentenceTransformer\n",
|
||||
"\n",
|
||||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"\n",
|
||||
"model_id = \"google/embeddinggemma-300M\"\n",
|
||||
"model = SentenceTransformer(model_id).to(device=device)\n",
|
||||
"\n",
|
||||
"print(f\"Device: {model.device}\")\n",
|
||||
"print(model)\n",
|
||||
"print(\"Total number of parameters in the model:\", sum([p.numel() for _, p in model.named_parameters()]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8o2-nOX-aqRS"
|
||||
},
|
||||
"source": [
|
||||
"### Using Prompts with EmbeddingGemma\n",
|
||||
"\n",
|
||||
"For RAG systems, use the following `prompt_name` values to create specialized embeddings for your queries and documents:\n",
|
||||
"\n",
|
||||
"* **For Queries:** Use `prompt_name=\"Retrieval-query\"`.<br>\n",
|
||||
" ```python\n",
|
||||
" query_embedding = model.encode(\n",
|
||||
" \"How do I use prompts with this model?\",\n",
|
||||
" prompt_name=\"Retrieval-query\"\n",
|
||||
" )\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"* **For Documents:** Use `prompt_name=\"Retrieval-document\"`. To further improve document embeddings, you can also include a title by using the `prompt` argument directly:<br>\n",
|
||||
" * **With a title:**<br>\n",
|
||||
" ```python\n",
|
||||
" doc_embedding = model.encode(\n",
|
||||
" \"The document text...\",\n",
|
||||
" prompt=\"title: Using Prompts in RAG | text: \"\n",
|
||||
" )\n",
|
||||
" ```\n",
|
||||
" * **Without a title:**<br>\n",
|
||||
" ```python\n",
|
||||
" doc_embedding = model.encode(\n",
|
||||
" \"The document text...\",\n",
|
||||
" prompt=\"title: none | text: \"\n",
|
||||
" )\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"### Further Reading\n",
|
||||
"\n",
|
||||
"* For details on all available EmbeddingGemma prompts, see the [model card](http://ai.google.dev/gemma/docs/embeddinggemma/model_card#prompt_instructions).\n",
|
||||
"* For general information on prompt templates, see the [Sentence Transformer documentation](https://sbert.net/examples/sentence_transformer/applications/computing-embeddings/README.html#prompt-templates).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"id": "Y5hVNF3F-qZ7"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Available tasks:\n",
|
||||
" query: \"task: search result | query: \"\n",
|
||||
" document: \"title: none | text: \"\n",
|
||||
" BitextMining: \"task: search result | query: \"\n",
|
||||
" Clustering: \"task: clustering | query: \"\n",
|
||||
" Classification: \"task: classification | query: \"\n",
|
||||
" InstructionRetrieval: \"task: code retrieval | query: \"\n",
|
||||
" MultilabelClassification: \"task: classification | query: \"\n",
|
||||
" PairClassification: \"task: sentence similarity | query: \"\n",
|
||||
" Reranking: \"task: search result | query: \"\n",
|
||||
" Retrieval: \"task: search result | query: \"\n",
|
||||
" Retrieval-query: \"task: search result | query: \"\n",
|
||||
" Retrieval-document: \"title: none | text: \"\n",
|
||||
" STS: \"task: sentence similarity | query: \"\n",
|
||||
" Summarization: \"task: summarization | query: \"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Available tasks:\")\n",
|
||||
"for name, prefix in model.prompts.items():\n",
|
||||
" print(f\" {name}: \\\"{prefix}\\\"\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eIfWZ_z3xDZq"
|
||||
},
|
||||
"source": [
|
||||
"## Simple RAG example\n",
|
||||
"\n",
|
||||
"Retrieval is the task of finding the most relevant pieces of information from a large collection (a database, a set of documents, a website) based on the meaning of a query, not just keywords.\n",
|
||||
"\n",
|
||||
"Imagine you work for a company, and you need to find information from the internal employee handbook, which is stored as a collection of hundreds of documents."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "fbaiy-CXRAs7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Corp knowledge base\n",
|
||||
"corp_knowledge_base = [\n",
|
||||
" {\n",
|
||||
" \"category\": \"HR & Leave Policies\",\n",
|
||||
" \"documents\": [\n",
|
||||
" {\n",
|
||||
" \"title\": \"Procedure for Unscheduled Absence\",\n",
|
||||
" \"content\": \"In the event of an illness or emergency preventing you from working, please notify both your direct manager and the HR department via email by 9:30 AM JST. The subject line should be 'Sick Leave - [Your Name]'. If the absence extends beyond two consecutive days, a doctor's certificate (診断書) will be required upon your return.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"title\": \"Annual Leave Policy\",\n",
|
||||
" \"content\": \"Full-time employees are granted 10 days of annual paid leave in their first year. This leave is granted six months after the date of joining and increases each year based on length of service. For example, an employee in their third year of service is entitled to 14 days per year. For a detailed breakdown, please refer to the attached 'Annual Leave Accrual Table'.\"\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"category\": \"IT & Security\",\n",
|
||||
" \"documents\": [\n",
|
||||
" {\n",
|
||||
" \"title\": \"Account Password Management\",\n",
|
||||
" \"content\": \"If you have forgotten your password or your account is locked, please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions. For security reasons, the IT Help Desk cannot reset passwords over the phone or email. If you have not set up your security questions, please visit the IT support desk on the 12th floor of the Shibuya office with your employee ID card.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"title\": \"Software Procurement Process\",\n",
|
||||
" \"content\": \"All requests for new software must be submitted through the 'IT Service Desk' portal under the 'Software Request' category. Please include a business justification for the request. All software licenses require approval from your department head before procurement can begin. Please note that standard productivity software is pre-approved and does not require this process.\"\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"category\": \"Finance & Expenses\",\n",
|
||||
" \"documents\": [\n",
|
||||
" {\n",
|
||||
" \"title\": \"Expense Reimbursement Policy\",\n",
|
||||
" \"content\": \"To ensure timely processing, all expense claims for a given month must be submitted for approval no later than the 5th business day of the following month. For example, all expenses incurred in July must be submitted by the 5th business day of August. Submissions after this deadline may be processed in the next payment cycle.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"title\": \"Business Trip Expense Guidelines\",\n",
|
||||
" \"content\": \"Travel expenses for business trips will, as a rule, be reimbursed based on the actual cost of the most logical and economical route. Please submit a travel expense application in advance when using the Shinkansen or airplanes. Taxis are permitted only when public transportation is unavailable or when transporting heavy equipment. Receipts are mandatory.\"\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"category\": \"Office & Facilities\",\n",
|
||||
" \"documents\": [\n",
|
||||
" {\n",
|
||||
" \"title\": \"Conference Room Booking Instructions\",\n",
|
||||
" \"content\": \"All conference rooms in the Shibuya office can be reserved through your Calendar App. Create a new meeting invitation, add the attendees, and then use the 'Room Finder' feature to select an available room. Please be sure to select the correct floor. For meetings with more than 10 people, please book the 'Sakura' or 'Fuji' rooms on the 14th floor.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"title\": \"Mail and Delivery Policy\",\n",
|
||||
" \"content\": \"The company's mail services are intended for business-related correspondence only. For security and liability reasons, employees are kindly requested to refrain from having personal parcels or mail delivered to the Shibuya office address. The front desk will not be able to accept or hold personal deliveries.\"\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Fvecfoko--hL"
|
||||
},
|
||||
"source": [
|
||||
"And imagine you have a question like below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"id": "wN-WHf26J89m"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"How do I reset my password?\" # @param [\"How many days of annual paid leave do I get?\", \"How do I reset my password?\", \"What travel expenses can be reimbursed for a business trip?\", \"Can I receive personal packages at the office?\"] {type:\"string\", allow-input: true}\n",
|
||||
"\n",
|
||||
"# Define a minimum confidence threshold for a match to be considered valid\n",
|
||||
"similarity_threshold = 0.4 # @param {\"type\":\"slider\",\"min\":0,\"max\":1,\"step\":0.1}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2CSeSmF7OuMB"
|
||||
},
|
||||
"source": [
|
||||
"Search relevant document from the corporate knowledge base."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"id": "NngqWUxOyrLS"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Step 1: Finding the best category...\n",
|
||||
"['HR & Leave Policies', 'IT & Security', 'Finance & Expenses', 'Office & Facilities']\n",
|
||||
"tensor([[0.5063, 0.5937, 0.5076, 0.4221]])\n",
|
||||
" `-> ✅ Category Found: 'IT & Security' (Score: 0.59)\n",
|
||||
"\n",
|
||||
"Step 2: Finding the best document in that category...\n",
|
||||
"['Account Password Management', 'Software Procurement Process']\n",
|
||||
"tensor([[0.5829, 0.1531]])\n",
|
||||
" `-> ✅ Document Found: 'Account Password Management' (Score: 0.58)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# --- Helper Functions for Semantic Search ---\n",
|
||||
"\n",
|
||||
"def _calculate_best_match(similarities):\n",
|
||||
" print(similarities)\n",
|
||||
" if similarities is None or similarities.nelement() == 0:\n",
|
||||
" return None, 0.0\n",
|
||||
"\n",
|
||||
" # Find the index and value of the highest score\n",
|
||||
" best_index = similarities.argmax().item()\n",
|
||||
" best_score = similarities[0, best_index].item()\n",
|
||||
"\n",
|
||||
" return best_index, best_score\n",
|
||||
"\n",
|
||||
"def find_best_category(model, query, candidates):\n",
|
||||
" \"\"\"\n",
|
||||
" Finds the most relevant category from a list of candidates.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" model: The SentenceTransformer model.\n",
|
||||
" query: The user's query string.\n",
|
||||
" candidates: A list of category name strings.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" A tuple containing the index of the best category and its similarity score.\n",
|
||||
" \"\"\"\n",
|
||||
" if not candidates:\n",
|
||||
" return None, 0.0\n",
|
||||
"\n",
|
||||
" # Encode the query and candidate categories for classification\n",
|
||||
" query_embedding = model.encode(query, prompt_name=\"Classification\")\n",
|
||||
" candidate_embeddings = model.encode(candidates, prompt_name=\"Classification\")\n",
|
||||
"\n",
|
||||
" print(candidates)\n",
|
||||
" return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))\n",
|
||||
"\n",
|
||||
"def find_best_doc(model, query, candidates):\n",
|
||||
" \"\"\"\n",
|
||||
" Finds the most relevant document from a list of candidates.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" model: The SentenceTransformer model.\n",
|
||||
" query: The user's query string.\n",
|
||||
" candidates: A list of document dictionaries, each with 'title' and 'content'.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" A tuple containing the index of the best document and its similarity score.\n",
|
||||
" \"\"\"\n",
|
||||
" if not candidates:\n",
|
||||
" return None, 0.0\n",
|
||||
"\n",
|
||||
" # Encode the query for retrieval\n",
|
||||
" query_embedding = model.encode(query, prompt_name=\"Retrieval-query\")\n",
|
||||
"\n",
|
||||
" # Encode the document for similarity check\n",
|
||||
" doc_texts = [\n",
|
||||
" f\"title: {doc.get('title', 'none')} | text: {doc.get('content', '')}\"\n",
|
||||
" for doc in candidates\n",
|
||||
" ]\n",
|
||||
" candidate_embeddings = model.encode(doc_texts)\n",
|
||||
"\n",
|
||||
" print([doc['title'] for doc in candidates])\n",
|
||||
"\n",
|
||||
" # Calculate cosine similarity\n",
|
||||
" return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))\n",
|
||||
"\n",
|
||||
"# --- Main Search Logic ---\n",
|
||||
"\n",
|
||||
"# In your application, `best_document` would result from a search.\n",
|
||||
"# We initialize it to None to ensure it always exists.\n",
|
||||
"best_document = None\n",
|
||||
"\n",
|
||||
"# 1. Find the most relevant category\n",
|
||||
"print(\"Step 1: Finding the best category...\")\n",
|
||||
"categories = [item[\"category\"] for item in corp_knowledge_base]\n",
|
||||
"best_category_index, category_score = find_best_category(\n",
|
||||
" model, question, categories\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Check if the category score meets the threshold\n",
|
||||
"if category_score < similarity_threshold:\n",
|
||||
" print(f\" `-> 🤷 No relevant category found. The highest score was only {category_score:.2f}.\")\n",
|
||||
"else:\n",
|
||||
" best_category = corp_knowledge_base[best_category_index]\n",
|
||||
" print(f\" `-> ✅ Category Found: '{best_category['category']}' (Score: {category_score:.2f})\")\n",
|
||||
"\n",
|
||||
" # 2. Find the most relevant document ONLY if a good category was found\n",
|
||||
" print(\"\\nStep 2: Finding the best document in that category...\")\n",
|
||||
" best_document_index, document_score = find_best_doc(\n",
|
||||
" model, question, best_category[\"documents\"]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Check if the document score meets the threshold\n",
|
||||
" if document_score < similarity_threshold:\n",
|
||||
" print(f\" `-> 🤷 No relevant document found. The highest score was only {document_score:.2f}.\")\n",
|
||||
" else:\n",
|
||||
" best_document = best_category[\"documents\"][best_document_index]\n",
|
||||
" # 3. Display the final successful result\n",
|
||||
" print(f\" `-> ✅ Document Found: '{best_document['title']}' (Score: {document_score:.2f})\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "zK9T5rRGAMDw"
|
||||
},
|
||||
"source": [
|
||||
"Next, generate the answer with the retrieved context"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"id": "FrwKySpMASpt"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Question🙋♂️: How do I reset my password?\n",
|
||||
"Using document: Account Password Management\n",
|
||||
"Answer🤖: Please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import GenerationConfig\n",
|
||||
"MODEL_ID = \"google/gemma-4-E2B-it\"\n",
|
||||
"config = GenerationConfig.from_pretrained(MODEL_ID)\n",
|
||||
"config.max_new_tokens = 512\n",
|
||||
"\n",
|
||||
"qa_prompt_template = \"\"\"Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write \"I don't know.\"\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"CONTEXT:\n",
|
||||
"{context}\n",
|
||||
"---\n",
|
||||
"QUESTION:\n",
|
||||
"{question}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# First, check if a valid document was found before proceeding.\n",
|
||||
"if best_document and \"content\" in best_document:\n",
|
||||
" # If the document exists and has a \"content\" key, generate the answer.\n",
|
||||
" context = best_document[\"content\"]\n",
|
||||
"\n",
|
||||
" prompt = qa_prompt_template.format(context=context, question=question)\n",
|
||||
"\n",
|
||||
" messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [{\"type\": \"text\", \"text\": prompt}],\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" print(\"Question🙋♂️: \" + question)\n",
|
||||
" # This part assumes your pipeline and response parsing logic are correct\n",
|
||||
" answer = pipeline(messages, generation_config=config)[0][\"generated_text\"][1][\"content\"]\n",
|
||||
" print(\"Using document: \" + best_document[\"title\"])\n",
|
||||
" print(\"Answer🤖: \" + answer)\n",
|
||||
"\n",
|
||||
"else:\n",
|
||||
" # If best_document is None or doesn't have content, give a direct response.\n",
|
||||
" print(\"Question🙋♂️: \" + question)\n",
|
||||
" print(\"Answer🤖: I'm sorry, I could not find a relevant document to answer that question.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "h4J4pFA3IK1d"
|
||||
},
|
||||
"source": [
|
||||
"## Summary and next steps\n",
|
||||
"\n",
|
||||
"You have now learned how to build a practical RAG system with EmbeddingGemma.\n",
|
||||
"\n",
|
||||
"Explore what more you can do with EmbeddingGemma:\n",
|
||||
"\n",
|
||||
"* [Generate embeddings with Sentence Transformers](https://ai.google.dev/gemma/docs/embeddinggemma/inference-embeddinggemma-with-sentence-transformers)\n",
|
||||
"* [Fine-tune EmbeddingGemma](https://ai.google.dev/gemma/docs/embeddinggemma/fine-tuning-embeddinggemma-with-sentence-transformers)\n",
|
||||
"* [Mood Palette Generator](https://huggingface.co/spaces/google/mood-palette), an interactive application using EmbeddingGemma"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"name": "RAG_with_EmbeddingGemma.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
# Gemma
|
||||
|
||||
[](https://github.com/google-deepmind/gemma/actions/workflows/pytest_and_autopublish.yml)
|
||||
[](https://badge.fury.io/py/gemma)
|
||||
[](https://gemma-llm.readthedocs.io/en/latest/?badge=latest)
|
||||
|
||||
[Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language
|
||||
Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini
|
||||
research and technology.
|
||||
|
||||
This repository contains the implementation of the
|
||||
[`gemma`](https://pypi.org/project/gemma/) PyPI package. A
|
||||
[JAX](https://github.com/jax-ml/jax) library to use and fine-tune Gemma.
|
||||
|
||||
For examples and use cases, see our
|
||||
[documentation](https://gemma-llm.readthedocs.io/). Please
|
||||
report issues and feedback in
|
||||
[our GitHub](https://github.com/google-deepmind/gemma/issues).
|
||||
|
||||
### Installation
|
||||
|
||||
1. Install JAX for CPU, GPU or TPU. Follow the instructions on
|
||||
[the JAX website](https://jax.readthedocs.io/en/latest/installation.html).
|
||||
1. Run
|
||||
|
||||
```sh
|
||||
pip install gemma
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
Here is a minimal example to have a multi-turn, multi-modal conversation with
|
||||
Gemma:
|
||||
|
||||
```python
|
||||
from gemma import gm
|
||||
|
||||
# Model and parameters (Gemma 4)
|
||||
model = gm.nn.Gemma4_E4B()
|
||||
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
|
||||
|
||||
# Example of multi-turn conversation
|
||||
sampler = gm.text.ChatSampler(
|
||||
model=model,
|
||||
params=params,
|
||||
multi_turn=True,
|
||||
)
|
||||
|
||||
prompt = """Which of the 2 images do you prefer ?
|
||||
|
||||
Image 1: <|image|>
|
||||
Image 2: <|image|>
|
||||
|
||||
Write your answer as a poem."""
|
||||
out0 = sampler.chat(prompt, images=[image1, image2])
|
||||
|
||||
out1 = sampler.chat('What about the other image ?')
|
||||
```
|
||||
|
||||
The same `ChatSampler` API works with all Gemma versions (2, 3, 3n, 4).
|
||||
|
||||
Our documentation contains various Colabs and tutorials, including:
|
||||
|
||||
* [Sampling](https://gemma-llm.readthedocs.io/en/latest/colab_sampling.html)
|
||||
* [Multi-modal](https://gemma-llm.readthedocs.io/en/latest/colab_multimodal.html)
|
||||
* [Fine-tuning](https://gemma-llm.readthedocs.io/en/latest/colab_finetuning.html)
|
||||
* [LoRA](https://gemma-llm.readthedocs.io/en/latest/colab_lora_sampling.html)
|
||||
* ...
|
||||
|
||||
Additionally, our
|
||||
[examples/](https://github.com/google-deepmind/gemma/tree/main/examples) folder
|
||||
contain additional scripts to fine-tune and sample with Gemma.
|
||||
|
||||
### Learn more about Gemma
|
||||
|
||||
* To use this library: [Gemma documentation](https://gemma-llm.readthedocs.io/)
|
||||
* Technical reports for metrics and model capabilities:
|
||||
* [Gemma 1](https://goo.gle/GemmaReport)
|
||||
* [Gemma 2](https://goo.gle/gemma2report)
|
||||
* [Gemma 3](https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf)
|
||||
* Gemma 4 (Coming soon)
|
||||
* Other Gemma implementations and doc on the
|
||||
[Gemma ecosystem](https://ai.google.dev/gemma/docs)
|
||||
|
||||
### Downloading the models
|
||||
|
||||
To download the model weights. See
|
||||
[our documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).
|
||||
|
||||
### System Requirements
|
||||
|
||||
Gemma can run on a CPU, GPU and TPU. For GPU, we recommend 8GB+ RAM on GPU for
|
||||
The 2B checkpoint and 24GB+ RAM on GPU are used for the 7B checkpoint.
|
||||
|
||||
### Contributing
|
||||
|
||||
We welcome contributions! Please read our [Contributing Guidelines](./CONTRIBUTING.md) before submitting a pull request.
|
||||
|
||||
*This is not an official Google product.*
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,568 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"metadata": {
|
||||
"id": "-KkvqLgjiIdD"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Tool Use\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/tool_use.ipynb)\n",
|
||||
"\n",
|
||||
"Demo to show how to use tool-use with Gemma library.\n",
|
||||
"\n",
|
||||
"Note: The Gemma 1, 2 and 3 models were not specifically trained for tool use. This is more a proof-of-concept than an officially supported feature."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "gcNRfVEnj4aq"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install -q gemma"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 2221,
|
||||
"status": "ok",
|
||||
"timestamp": 1749202985345,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "k1ZAgLg1j9NT"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Common imports\n",
|
||||
"import os\n",
|
||||
"import datetime\n",
|
||||
"\n",
|
||||
"# Gemma imports\n",
|
||||
"from gemma import gm"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 3
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "139lZszJj_CC"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"By default, Jax does not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 2,
|
||||
"status": "ok",
|
||||
"timestamp": 1749138071985,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "VtlWWLIYj_LJ"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 2
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "31JPZb5RkD_p"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Load the model and the params."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 39057,
|
||||
"status": "ok",
|
||||
"timestamp": 1749203024713,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "RsAo6k4_kEJS",
|
||||
"outputId": "e10afb5c-6c81-42e8-e590-a39ea4ef3bf7"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"model = gm.nn.Gemma3_4B()\n",
|
||||
"\n",
|
||||
"params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:2025-06-06 02:43:16,896:jax._src.xla_bridge:749: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'\n",
|
||||
"INFO:2025-06-06 02:43:16,897:jax._src.xla_bridge:749: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got \n",
|
||||
"INFO:2025-06-06 02:43:16,900:jax._src.xla_bridge:749: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'\n",
|
||||
"INFO:2025-06-06 02:43:16,901:jax._src.xla_bridge:749: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 4
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "p108c5yIlYH7"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Using existing tools\n",
|
||||
"\n",
|
||||
"If you're familiar with the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial, using tool-use differ in two ways:\n",
|
||||
"\n",
|
||||
"1. Using the `gm.text.ToolSampler` rather than the `gm.text.ChatSampler`.\n",
|
||||
"2. Passing the `tools=` you want to use to the sampler.\n",
|
||||
"\n",
|
||||
"For example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 594
|
||||
},
|
||||
"executionInfo": {
|
||||
"elapsed": 50615,
|
||||
"status": "ok",
|
||||
"timestamp": 1749138791069,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "iRCV5h8BlVX6",
|
||||
"outputId": "b3b5d83d-8a8b-4982-fc8f-d409fb8b38a9"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"sampler = gm.text.ToolSampler(\n",
|
||||
" model=model,\n",
|
||||
" params=params,\n",
|
||||
" tools=[\n",
|
||||
" gm.tools.Calculator(),\n",
|
||||
" gm.tools.FileExplorer(),\n",
|
||||
" ],\n",
|
||||
" print_stream=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.\n",
|
||||
"Let's start with S0 = 3.\n",
|
||||
"S1 = cos(S0) * 2 = cos(3) * 2\n",
|
||||
"S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2\n",
|
||||
"S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2\n",
|
||||
"S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2\n",
|
||||
"\n",
|
||||
"I will use the calculator to compute these values.\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(3) * 2\"}\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: -1.9799849932008908]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"-1.9799849932008908 * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: -3.9599699864017817]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-3.9599699864017817) * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: -1.3668134299076982]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-1.3668134299076982) * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: 0.4051424976130353]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(0.4051424976130353) * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: 1.8380924822033438]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 10
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "FAI54F-Blkan"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Note: Only the final model answer is returned. You can access the conversation history, including all intermediates tool calls and output through `sampler.turns` property."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "D0_IIS1Nlfuw"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Creating your own tool\n",
|
||||
"\n",
|
||||
"To create your own tool, you can inherit from the `gm.tools.Tool` class. You should provide:\n",
|
||||
"\n",
|
||||
"* A description & example, so the model knows how to use your tool\n",
|
||||
"* Implement the `call` method. The `call` function can take arbitrary `**kwargs`, but the name of the args should match the ones defined in `tool_kwargs` and `tool_kwargs_doc`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 55,
|
||||
"status": "ok",
|
||||
"timestamp": 1749203934196,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "XqmQcfdI0oEl"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"class DateTime(gm.tools.Tool):\n",
|
||||
" \"\"\"Tool to access the current date.\"\"\"\n",
|
||||
"\n",
|
||||
" DESCRIPTION = 'Access the current date, time,...'\n",
|
||||
" EXAMPLE = gm.tools.Example(\n",
|
||||
" query='Which day of the week are we today ?',\n",
|
||||
" thought='The `datetime.strptime` uses %a for day of the week',\n",
|
||||
" tool_kwargs={'format': '%a'},\n",
|
||||
" tool_kwargs_doc={'format': '<ANY datetime.strptime expression>'},\n",
|
||||
" result='Sat',\n",
|
||||
" answer='Today is Saturday.',\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def call(self, format: str) -> str:\n",
|
||||
" dt = datetime.datetime.now()\n",
|
||||
" return dt.strftime(format)\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 7
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "sSxYhXPuuXYp"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The tool can then be used in the sampler:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 118
|
||||
},
|
||||
"executionInfo": {
|
||||
"elapsed": 2156,
|
||||
"status": "ok",
|
||||
"timestamp": 1749204833094,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "9S8xB2B-0cbW",
|
||||
"outputId": "fccc0e89-e922-4184-8b77-800041cdd77e"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"sampler = gm.text.ToolSampler(\n",
|
||||
" model=model,\n",
|
||||
" params=params,\n",
|
||||
" tools=[\n",
|
||||
" DateTime(),\n",
|
||||
" ],\n",
|
||||
" print_stream=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"output = sampler.chat('Which date are we today ?')"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: I need to get the current date.\n",
|
||||
"{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: 2025-06-06]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Today is June 6th, 2025."
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 9
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "esIpCjhxzHmf"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Next steps\n",
|
||||
"\n",
|
||||
"* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n",
|
||||
"* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"last_runtime": {},
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example config for finetuning Gemma for a classification task.
|
||||
|
||||
* Input: A text to classify.
|
||||
* Output: A classification label. The pre-trained Gemma model is trained to
|
||||
predict one world among 256.000. Here, we're finetuning to predict only 2
|
||||
tokens among the 256.000 available.
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/classification.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Get the default hyperparameter configuration."""
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(training=True),
|
||||
# Model definition
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.sentence",
|
||||
return_last_only=True,
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.label",
|
||||
),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-4),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(training=False),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(training: bool) -> kd.data.Pipeline:
|
||||
# Dict key names from the dataset
|
||||
_INPUT_FIELD = "sentence" # pylint: disable=invalid-name
|
||||
_LABEL_FIELD = "label" # pylint: disable=invalid-name
|
||||
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="glue/cola",
|
||||
split="train" if training else "validation",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=8,
|
||||
transforms=[
|
||||
# Process the input text
|
||||
# TFDS datasets returns `bytes`, so convert them to `str`
|
||||
gm.data.DecodeBytes(key=_INPUT_FIELD),
|
||||
gm.data.FormatText(
|
||||
key=_INPUT_FIELD,
|
||||
template="""<start_of_turn>user
|
||||
Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
|
||||
Sentence: {text}<end_of_turn>
|
||||
<start_of_turn>model""",
|
||||
),
|
||||
gm.data.Tokenize(
|
||||
key=_INPUT_FIELD,
|
||||
tokenizer=tokenizer,
|
||||
add_bos=True,
|
||||
),
|
||||
gm.data.Pad(
|
||||
key=_INPUT_FIELD,
|
||||
max_length=128,
|
||||
),
|
||||
# Process the label
|
||||
gm.data.MapInts(
|
||||
key=_LABEL_FIELD,
|
||||
# Rather than predicting the token 0 and 1, we are using the
|
||||
# token 1294 and 3553 which respectivelly correspond to "No" and
|
||||
# "Yes". We do this because those token already contain semantic
|
||||
# information, so even zero-shot prediction without any
|
||||
# finetuning has better than random performances.
|
||||
old_to_new={
|
||||
0: 1294, # Token -> "No"
|
||||
1: 3553, # Token -> "Yes"
|
||||
},
|
||||
),
|
||||
kd.data.Rearrange(
|
||||
key=_LABEL_FIELD,
|
||||
pattern="... -> ... 1", # For shape compatibility with the loss.
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""DPO Example.
|
||||
|
||||
DPO works by running two answers (one prefered and one rejected) into both
|
||||
the reference model and the model to finetune. Then the DPO loss is used to
|
||||
increase the likelihood of generating the preferred answer.
|
||||
|
||||
Implementation wise, this is done by:
|
||||
|
||||
* Wrapping the model inside a `gm.nn.AnchoredPolicy` (which runs both the
|
||||
model and the reference frozen model)
|
||||
* Using the `gm.ckpts.AnchoredPolicyLoader` to restore the weights, so the
|
||||
weights are correctly mapped to inside `gm.nn.AnchoredPolicy`.
|
||||
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/dpo.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Get the default hyperparameter configuration."""
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(training=True),
|
||||
# Model definition
|
||||
model=gm.nn.AnchoredPolicy(
|
||||
policy=gm.nn.Gemma3_4B(tokens="batch.tokens", text_only=True),
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.AnchoredPolicyLoader(
|
||||
policy=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"dpo": gm.losses.DpoLoss(
|
||||
tokens="batch.targets",
|
||||
sequence_mask="batch.mask",
|
||||
policy_logits="preds.policy.logits",
|
||||
anchor_logits="preds.anchor.logits",
|
||||
),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-4),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
# "test": kd.evals.Evaluator(
|
||||
# run=kd.evals.EveryNSteps(1000),
|
||||
# ds=_make_dataset(training=False),
|
||||
# ),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(training: bool) -> kd.data.Pipeline:
|
||||
# TODO(epot): !!!!
|
||||
max_length = 512
|
||||
batch_size = 16
|
||||
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.HuggingFace(
|
||||
path="argilla/distilabel-math-preference-dpo",
|
||||
split="train",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=batch_size,
|
||||
transforms=[
|
||||
# Only keep the fields we need.
|
||||
kd.data.Elements(
|
||||
keep=["instruction", "chosen_response", "rejected_response"]
|
||||
),
|
||||
# Create the model inputs and loss mask.
|
||||
gm.data.ContrastiveTask(
|
||||
in_prompt="instruction",
|
||||
in_chosen="chosen_response",
|
||||
in_rejected="rejected_response",
|
||||
out_tokens="tokens",
|
||||
out_targets="targets",
|
||||
out_mask="mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=max_length,
|
||||
# TODO(epot): Run stats (how many examples are we dropping?)
|
||||
truncate=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,154 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example of Gemma finetuning using LoRA.
|
||||
|
||||
This example is based on the `seq2seq.py` example. See the
|
||||
docstring of that file for more details.
|
||||
|
||||
The changes to use LoRA are:
|
||||
|
||||
* `model`: Use `gm.nn.LoRA()` wrapper to add `LoRA` adapters to the
|
||||
model.
|
||||
* `init_transform`: Use `gm.ckpts.SkipLoRA()` wrapper to only restore the
|
||||
non-LoRA weights.
|
||||
* `optimizer`: Use `kd.optim.partial_updates` wrapper to only train the LoRA
|
||||
weights.
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/lora.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
batch_size = 16
|
||||
max_length = 512
|
||||
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(
|
||||
training=True,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
# Model definition
|
||||
model=gm.nn.LoRA(
|
||||
rank=4,
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.input",
|
||||
# TODO(epot): At the moment, LoRA fine-tuning with multimodal
|
||||
# is not supported. Willbe fixed soon.
|
||||
text_only=True,
|
||||
),
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
# Use `SkipLoRA` as the original checkpoint does not contain the LoRA
|
||||
# weights.
|
||||
init_transform=gm.ckpts.SkipLoRA(
|
||||
wrapped=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
)
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.target",
|
||||
mask="batch.loss_mask",
|
||||
),
|
||||
},
|
||||
# TODO(epot): Add Gradient accumenlation.
|
||||
optimizer=kd.optim.partial_updates(
|
||||
optax.adafactor(learning_rate=0.005),
|
||||
# We only optimize the LoRA weights. The rest of the model is frozen.
|
||||
mask=kd.optim.select("lora"),
|
||||
),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(
|
||||
training=False,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
),
|
||||
# The sampler evaluator run inference on a few prompts from the
|
||||
# test set.
|
||||
"sampling": gm.evals.SamplerEvaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
max_new_tokens=150, # Sampling parameters
|
||||
num_batches=1, # Only predict a single example (batch_size=None)
|
||||
ds=_make_dataset(training=False, sampling=True),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
training: bool,
|
||||
sampling: bool = False,
|
||||
batch_size: int | None = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="mtnt/en-fr",
|
||||
split="train" if training else "test",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=None if sampling else batch_size,
|
||||
num_workers=4,
|
||||
transforms=[
|
||||
# Create the model inputs/targets/loss_mask.
|
||||
gm.data.Seq2SeqTask(
|
||||
# Select which field from the dataset to use.
|
||||
# https://www.tensorflow.org/datasets/catalog/mtnt
|
||||
in_prompt="src",
|
||||
in_response="dst",
|
||||
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
|
||||
out_input="input",
|
||||
out_target="target",
|
||||
out_target_mask="loss_mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=None if sampling else max_length,
|
||||
# In this dataset, ~1% of examples are longer than 512 tokens.
|
||||
truncate=True,
|
||||
sampling=sampling,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example of Gemma finetuning for an image captioning task.
|
||||
|
||||
Example:
|
||||
|
||||
Prompt:
|
||||
|
||||
```
|
||||
<start_of_turn>user
|
||||
<start_of_image><end_of_turn>
|
||||
<start_of_turn>model
|
||||
```
|
||||
|
||||
Target:
|
||||
|
||||
```
|
||||
A diagram showing a circuit with a battery, lamp, and switch.<end_of_turn>
|
||||
```
|
||||
|
||||
Here, the prompt only contains the `<start_of_image>` to indicate an image
|
||||
is inserted.
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/multimodal.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
import jax.numpy as jnp
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
batch_size = 32
|
||||
max_length = 200
|
||||
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(
|
||||
training=True,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
# Model definition
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.input",
|
||||
images="batch.image",
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.target",
|
||||
mask="batch.loss_mask",
|
||||
),
|
||||
},
|
||||
train_summaries={
|
||||
"image": kd.summaries.ShowImages(images="batch.image", num_images=5),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-3),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(
|
||||
training=False,
|
||||
batch_size=4,
|
||||
max_length=max_length,
|
||||
),
|
||||
),
|
||||
# The sampler evaluator run inference on a few prompts from the
|
||||
# test set.
|
||||
"sampling": gm.evals.SamplerEvaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
max_new_tokens=50, # Sampling parameters
|
||||
num_batches=3,
|
||||
ds=_make_dataset(training=False, sampling=True),
|
||||
summaries={
|
||||
"image": kd.summaries.ShowImages(
|
||||
images="batch.image", num_images=5
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
training: bool,
|
||||
sampling: bool = False,
|
||||
batch_size: int | None = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="ai2dcaption",
|
||||
split="llava_15" if training else "test",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=None if sampling else batch_size,
|
||||
num_workers=4,
|
||||
transforms=[
|
||||
# Only keep the fields we need.See fields at:
|
||||
# https://www.tensorflow.org/datasets/catalog/ai2dcaption
|
||||
kd.data.Elements(keep=["image", "caption"]),
|
||||
# Create a new constant field
|
||||
kd.data.AddConstants({"prompt": "<start_of_image>"}),
|
||||
# Create the model inputs/targets/loss_mask.
|
||||
gm.data.Seq2SeqTask(
|
||||
# Select which field from the dataset to use.
|
||||
in_prompt="prompt",
|
||||
in_response="caption",
|
||||
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
|
||||
out_input="input",
|
||||
out_target="target",
|
||||
out_target_mask="loss_mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=None if sampling else max_length,
|
||||
# In this dataset, ~1% of examples are longer than 512 tokens.
|
||||
truncate=True,
|
||||
sampling=sampling,
|
||||
),
|
||||
kd.data.py.Resize(key="image", size=(800, 800)),
|
||||
# TODO(epot): Make the `num_images` dimension optional
|
||||
kd.data.Rearrange(key="image", pattern="... h w c -> ... 1 h w c"),
|
||||
kd.data.Cast(key="image", dtype=jnp.uint8),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example of Gemma finetuning for a prompt -> response task.
|
||||
|
||||
This is a fork of the seq2seq example, but with sharding.
|
||||
The only difference is the `sharding=kd.sharding.ShardingStrategy()`
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/sharding.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
batch_size = 16
|
||||
max_length = 512
|
||||
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(
|
||||
training=True,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
# Model definition
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.input",
|
||||
),
|
||||
sharding=kd.sharding.ShardingStrategy(
|
||||
params=kd.sharding.FSDPSharding(),
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.target",
|
||||
mask="batch.loss_mask",
|
||||
),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-3),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(
|
||||
training=False,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
),
|
||||
# The sampler evaluator run inference on a few prompts from the
|
||||
# test set.
|
||||
"sampling": gm.evals.SamplerEvaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
max_new_tokens=50, # Sampling parameters
|
||||
num_batches=1, # Only predict a single example (batch_size=None)
|
||||
ds=_make_dataset(training=False, sampling=True),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
training: bool,
|
||||
sampling: bool = False,
|
||||
batch_size: int | None = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="mtnt/en-fr",
|
||||
split="train" if training else "test",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=None if sampling else batch_size,
|
||||
num_workers=4,
|
||||
transforms=[
|
||||
# Create the model inputs/targets/loss_mask.
|
||||
gm.data.Seq2SeqTask(
|
||||
# Select which field from the dataset to use.
|
||||
# https://www.tensorflow.org/datasets/catalog/mtnt
|
||||
in_prompt="src",
|
||||
in_response="dst",
|
||||
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
|
||||
out_input="input",
|
||||
out_target="target",
|
||||
out_target_mask="loss_mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=None if sampling else max_length,
|
||||
# In this dataset, ~1% of examples are longer than 512 tokens.
|
||||
truncate=True,
|
||||
sampling=sampling,
|
||||
),
|
||||
],
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,250 @@
|
||||
# Gemma.cpp API Server
|
||||
|
||||
This is an HTTP API server for gemma.cpp that implements the Google API protocol, allowing you to interact with Gemma models through REST API endpoints compatible with the Google API format.
|
||||
|
||||
## Features
|
||||
|
||||
- **API-compatible**: Implements Google API endpoints
|
||||
- **Unified client/server**: Single codebase supports both local and public API modes
|
||||
- **Text generation**: Support for `generateContent` endpoint
|
||||
- **Streaming support**: Server-Sent Events (SSE) for `streamGenerateContent`
|
||||
- **Model management**: Support for `/v1beta/models` endpoint
|
||||
- **Session management**: Maintains conversation context with KV cache
|
||||
- **JSON responses**: All responses in Google API format
|
||||
- **Error handling**: Proper HTTP status codes and error messages
|
||||
|
||||
## Building
|
||||
|
||||
The API server is built alongside the main gemma.cpp project:
|
||||
|
||||
```bash
|
||||
# Configure the build
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# Build the API server and client
|
||||
cmake --build build --target gemma_api_server gemma_api_client -j 8
|
||||
```
|
||||
|
||||
The binaries will be created at:
|
||||
- `build/gemma_api_server` - Local API server
|
||||
- `build/gemma_api_client` - Unified client for both local and public APIs
|
||||
|
||||
## Usage
|
||||
|
||||
### Starting the Local API Server
|
||||
|
||||
```bash
|
||||
./build/gemma_api_server \
|
||||
--tokenizer path/to/tokenizer.spm \
|
||||
--weights path/to/model.sbs \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
**Required arguments:**
|
||||
- `--tokenizer`: Path to the tokenizer file (`.spm`)
|
||||
- `--weights`: Path to the model weights file (`.sbs`)
|
||||
|
||||
**Optional arguments:**
|
||||
- `--port`: Port to listen on (default: 8080)
|
||||
- `--model`: Model name for API endpoints (default: gemma3-4b)
|
||||
|
||||
### Using the Unified Client
|
||||
|
||||
#### With Local Server
|
||||
```bash
|
||||
# Interactive chat with local server
|
||||
./build/gemma_api_client --interactive 1 --host localhost --port 8080
|
||||
|
||||
# Single prompt with local server
|
||||
./build/gemma_api_client --prompt "Hello, how are you?"
|
||||
```
|
||||
|
||||
#### With Public Google API
|
||||
```bash
|
||||
# Set API key and use public API
|
||||
export GOOGLE_API_KEY="your-api-key-here"
|
||||
./build/gemma_api_client --interactive 1
|
||||
|
||||
# Or pass API key directly
|
||||
./build/gemma_api_client --api_key "your-api-key" --interactive 1
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
The server implements Google API endpoints:
|
||||
|
||||
### 1. Generate Content - `POST /v1beta/models/gemma3-4b:generateContent`
|
||||
|
||||
Generate a response for given content (non-streaming).
|
||||
|
||||
**Request:**
|
||||
```json
|
||||
{
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{"text": "Why is the sky blue?"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"generationConfig": {
|
||||
"temperature": 0.9,
|
||||
"topK": 1,
|
||||
"maxOutputTokens": 1024
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{"text": "The sky appears blue because..."}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"promptFeedback": {
|
||||
"safetyRatings": []
|
||||
},
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 5,
|
||||
"candidatesTokenCount": 25,
|
||||
"totalTokenCount": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Stream Generate Content - `POST /v1beta/models/gemma3-4b:streamGenerateContent`
|
||||
|
||||
Generate a response with Server-Sent Events (SSE) streaming.
|
||||
|
||||
**Request:** Same as above
|
||||
|
||||
**Response:** Stream of SSE events:
|
||||
```
|
||||
data: {"candidates":[{"content":{"parts":[{"text":"The"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}
|
||||
|
||||
data: {"candidates":[{"content":{"parts":[{"text":" sky"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
### 3. List Models - `GET /v1beta/models`
|
||||
|
||||
List available models.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "models/gemma3-4b",
|
||||
"displayName": "Gemma3 4B",
|
||||
"description": "Gemma3 4B model running locally"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Using curl with Local Server
|
||||
|
||||
```bash
|
||||
# Generate content (non-streaming)
|
||||
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [{"parts": [{"text": "Hello, how are you?"}]}],
|
||||
"generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024}
|
||||
}'
|
||||
|
||||
# Stream generate content (SSE)
|
||||
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:streamGenerateContent \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [{"parts": [{"text": "Tell me a story"}]}],
|
||||
"generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024}
|
||||
}'
|
||||
|
||||
# List models
|
||||
curl http://localhost:8080/v1beta/models
|
||||
```
|
||||
|
||||
### Multi-turn Conversation with curl
|
||||
|
||||
```bash
|
||||
# First message
|
||||
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [
|
||||
{"parts": [{"text": "Hi, my name is Alice"}]}
|
||||
]
|
||||
}'
|
||||
|
||||
# Follow-up message with conversation history
|
||||
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [
|
||||
{"parts": [{"text": "Hi, my name is Alice"}]},
|
||||
{"parts": [{"text": "Hello Alice! Nice to meet you."}]},
|
||||
{"parts": [{"text": "What is my name?"}]}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Using Python
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# Generate content
|
||||
response = requests.post('http://localhost:8080/v1beta/models/gemma3-4b:generateContent',
|
||||
json={
|
||||
'contents': [{'parts': [{'text': 'Explain quantum computing in simple terms'}]}],
|
||||
'generationConfig': {
|
||||
'temperature': 0.9,
|
||||
'topK': 1,
|
||||
'maxOutputTokens': 1024
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
if 'candidates' in result and result['candidates']:
|
||||
text = result['candidates'][0]['content']['parts'][0]['text']
|
||||
print(text)
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The Google API supports various generation configuration options:
|
||||
|
||||
- **temperature**: Controls randomness (0.0 to 2.0, default: 1.0)
|
||||
- **topK**: Top-K sampling parameter (default: 1)
|
||||
- **maxOutputTokens**: Maximum number of tokens to generate (default: 8192)
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Unified Implementation**: Same codebase handles both local server and public API
|
||||
- **Session Management**: Maintains conversation context using KV cache
|
||||
- **Streaming Support**: Real-time token generation via Server-Sent Events
|
||||
- **Error Handling**: Comprehensive error responses and HTTP status codes
|
||||
- **Memory Efficient**: Optimized token processing and caching
|
||||
|
||||
## Compatibility
|
||||
|
||||
This implementation is compatible with:
|
||||
- Google API format and endpoints
|
||||
- Standard HTTP clients (curl, browsers, Python requests, etc.)
|
||||
- Server-Sent Events (SSE) for streaming responses
|
||||
- JSON request/response format
|
||||
@@ -0,0 +1,532 @@
|
||||
# gemma.cpp
|
||||
|
||||
gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
|
||||
foundation models from Google.
|
||||
|
||||
For additional information about Gemma, see
|
||||
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
|
||||
gemma.cpp specific artifacts, are
|
||||
[available on kaggle](https://www.kaggle.com/models/google/gemma-2).
|
||||
|
||||
## Who is this project for?
|
||||
|
||||
Modern LLM inference engines are sophisticated systems, often with bespoke
|
||||
capabilities extending beyond traditional neural network runtimes. With this
|
||||
comes opportunities for research and innovation through co-design of high level
|
||||
algorithms and low-level computation. However, there is a gap between
|
||||
deployment-oriented C++ inference runtimes, which are not designed for
|
||||
experimentation, and Python-centric ML research frameworks, which abstract away
|
||||
low-level computation through compilation.
|
||||
|
||||
gemma.cpp provides a minimalist implementation of Gemma-2, Gemma-3, and
|
||||
PaliGemma-2 models, focusing on simplicity and directness rather than full
|
||||
generality. This is inspired by vertically-integrated model implementations such
|
||||
as [ggml](https://github.com/ggerganov/ggml),
|
||||
[llama.c](https://github.com/karpathy/llama2.c), and
|
||||
[llama.rs](https://github.com/srush/llama2.rs).
|
||||
|
||||
gemma.cpp targets experimentation and research use cases. It is intended to be
|
||||
straightforward to embed in other projects with minimal dependencies and also
|
||||
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
|
||||
of supporting utilities). We use the [Google
|
||||
Highway](https://github.com/google/highway) Library to take advantage of
|
||||
portable SIMD for CPU inference.
|
||||
|
||||
For production-oriented edge deployments we recommend standard deployment
|
||||
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
|
||||
([all model variations here](https://www.kaggle.com/models/google/gemma)).
|
||||
|
||||
## Contributing
|
||||
|
||||
Community contributions large and small are welcome. See
|
||||
[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md)
|
||||
for additional notes contributing developers and [join the discord by following
|
||||
this invite link](https://discord.gg/H5jCBAWxAe). This project follows
|
||||
[Google's Open Source Community
|
||||
Guidelines](https://opensource.google.com/conduct/).
|
||||
|
||||
> [!NOTE] Active development is currently done on the `dev` branch. Please open
|
||||
> pull requests targeting `dev` branch instead of `main`, which is intended to
|
||||
> be more stable.
|
||||
|
||||
## What's inside?
|
||||
|
||||
- LLM
|
||||
|
||||
- CPU-only inference for: Gemma 2-3, PaliGemma 2.
|
||||
- Sampling with TopK and temperature.
|
||||
- Backward pass (VJP) and Adam optimizer for Gemma research.
|
||||
|
||||
- Optimizations
|
||||
|
||||
- Mixed-precision (fp8, bf16, fp32, fp64 bit) GEMM:
|
||||
- Designed for BF16 instructions, can efficiently emulate them.
|
||||
- Automatic runtime autotuning 7 parameters per matrix shape.
|
||||
- Weight compression integrated directly into GEMM:
|
||||
- Custom fp8 format with 2..3 mantissa bits; tensor scaling.
|
||||
- Also bf16, f32 and non-uniform 4-bit (NUQ); easy to add new formats.
|
||||
|
||||
- Infrastructure
|
||||
|
||||
- SIMD: single implementation via Highway. Chooses ISA at runtime.
|
||||
- Tensor parallelism: CCX-aware, multi-socket thread pool.
|
||||
- Disk I/O: memory map or parallel read (heuristic with user override).
|
||||
- Custom format with forward/backward-compatible metadata serialization.
|
||||
- Model conversion from Safetensors, not yet open sourced.
|
||||
- Portability: Linux, Windows/OS X supported. CMake/Bazel. 'Any' CPU.
|
||||
|
||||
- Frontends
|
||||
|
||||
- C++ APIs with streaming for single query and batched inference.
|
||||
- Basic interactive command-line app.
|
||||
- Basic Python bindings (pybind11).
|
||||
|
||||
## Quick Start
|
||||
|
||||
### System requirements
|
||||
|
||||
Before starting, you should have installed:
|
||||
|
||||
- [CMake](https://cmake.org/)
|
||||
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
|
||||
least C++17.
|
||||
- `tar` for extracting archives from Kaggle.
|
||||
|
||||
Building natively on Windows requires the Visual Studio 2012 Build Tools with the
|
||||
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
|
||||
command line with
|
||||
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
|
||||
|
||||
```sh
|
||||
winget install --id Kitware.CMake
|
||||
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
|
||||
```
|
||||
|
||||
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
|
||||
|
||||
Visit the
|
||||
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
|
||||
and select `Model Variations |> Gemma C++`.
|
||||
|
||||
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
|
||||
weights are higher fidelity, while 8-bit switched floating point weights enable
|
||||
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
|
||||
|
||||
> [!NOTE] **Important**: We strongly recommend starting off with the
|
||||
> `gemma2-2b-it-sfp` model to get up and running.
|
||||
|
||||
Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
|
||||
`ModelPrefix` function in `configs.cc`.
|
||||
|
||||
### Step 2: Extract Files
|
||||
|
||||
After filling out the consent form, the download should proceed to retrieve a
|
||||
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
|
||||
take a few minutes):
|
||||
|
||||
```
|
||||
tar -xf archive.tar.gz
|
||||
```
|
||||
|
||||
This should produce a file containing model weights such as `2b-it-sfp.sbs` and
|
||||
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
|
||||
convenient directory location (e.g. the `build/` directory in this repo).
|
||||
|
||||
### Step 3: Build
|
||||
|
||||
The build system uses [CMake](https://cmake.org/). To build the gemma inference
|
||||
runtime, create a build directory and generate the build files using `cmake`
|
||||
from the top-level project directory. Note if you previous ran `cmake` and are
|
||||
re-running with a different setting, be sure to delete all files in the `build/`
|
||||
directory with `rm -rf build/*`.
|
||||
|
||||
#### Unix-like Platforms
|
||||
```sh
|
||||
cmake -B build
|
||||
```
|
||||
|
||||
After running `cmake`, you can enter the `build/` directory and run `make` to
|
||||
build the `./gemma` executable:
|
||||
|
||||
```sh
|
||||
# Configure `build` directory
|
||||
cmake --preset make
|
||||
|
||||
# Build project using make
|
||||
cmake --build --preset make -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
Replace `[number of parallel threads to use]` with a number - the number of
|
||||
cores available on your system is a reasonable heuristic. For example, `make -j4
|
||||
gemma` will build using 4 threads. If the `nproc` command is available, you can
|
||||
use `make -j$(nproc) gemma` as a reasonable default for the number of threads.
|
||||
|
||||
If you aren't sure of the right value for the `-j` flag, you can simply run
|
||||
`make gemma` instead and it should still build the `./gemma` executable.
|
||||
|
||||
> [!NOTE]
|
||||
> On Windows Subsystem for Linux (WSL) users should set the number of
|
||||
> parallel threads to 1. Using a larger number may result in errors.
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the
|
||||
`build/` directory.
|
||||
|
||||
#### Windows
|
||||
|
||||
```sh
|
||||
# Configure `build` directory
|
||||
cmake --preset windows
|
||||
|
||||
# Build project using Visual Studio Build Tools
|
||||
cmake --build --preset windows -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma.exe` executable in the
|
||||
`build/` directory.
|
||||
|
||||
#### Bazel
|
||||
|
||||
```sh
|
||||
bazel build -c opt --cxxopt=-std=c++20 :gemma
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the
|
||||
`bazel-bin/` directory.
|
||||
|
||||
#### Make
|
||||
|
||||
If you prefer Makefiles, @jart has made one available here:
|
||||
|
||||
https://github.com/jart/gemma3/blob/main/Makefile
|
||||
|
||||
### Step 4: Run
|
||||
|
||||
You can now run `gemma` from inside the `build/` directory.
|
||||
|
||||
`gemma` has the following required arguments:
|
||||
|
||||
Argument | Description | Example value
|
||||
------------- | ---------------------------- | ---------------
|
||||
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
|
||||
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
|
||||
|
||||
Example invocation for the following configuration:
|
||||
|
||||
- weights file `gemma2-2b-it-sfp.sbs` (Gemma2 2B instruction-tuned model,
|
||||
8-bit switched floating point).
|
||||
- Tokenizer file `tokenizer.spm` (can omit for single-format weights files
|
||||
created after 2025-05-06, or output by migrate_weights.cc).
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
|
||||
```
|
||||
|
||||
### PaliGemma Vision-Language Model
|
||||
|
||||
This repository includes a version of the PaliGemma 2 VLM
|
||||
([paper](https://arxiv.org/abs/2412.03555)). We provide a C++ implementation of
|
||||
the PaliGemma 2 model here.
|
||||
|
||||
To use the version of PaliGemma included in this repository, build the gemma
|
||||
binary as noted above in Step 3. Download the compressed weights and tokenizer
|
||||
from
|
||||
[Kaggle](https://www.kaggle.com/models/google/paligemma-2/gemmaCpp/paligemma2-3b-mix-224)
|
||||
and run the binary as follows:
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer paligemma_tokenizer.model \
|
||||
--weights paligemma2-3b-mix-224-sfp.sbs \
|
||||
--image_file paligemma/testdata/image.ppm
|
||||
```
|
||||
|
||||
Note that the image reading code is very basic to avoid depending on an image
|
||||
processing library for now. We currently only support reading binary PPMs (P6).
|
||||
So use a tool like `convert` to first convert your images into that format, e.g.
|
||||
|
||||
`convert image.jpeg -resize 224x224^ image.ppm`
|
||||
|
||||
(As the image will be resized for processing anyway, we can already resize at
|
||||
this stage for slightly faster loading.)
|
||||
|
||||
The interaction with the image (using the mix-224 checkpoint) may then look
|
||||
something like this:
|
||||
|
||||
```
|
||||
> Describe the image briefly
|
||||
A large building with two towers in the middle of a city.
|
||||
> What type of building is it?
|
||||
church
|
||||
> What color is the church?
|
||||
gray
|
||||
> caption image
|
||||
A large building with two towers stands tall on the water's edge. The building
|
||||
has a brown roof and a window on the side. A tree stands in front of the
|
||||
building, and a flag waves proudly from its top. The water is calm and blue,
|
||||
reflecting the sky above. A bridge crosses the water, and a red and white boat
|
||||
rests on its surface. The building has a window on the side, and a flag on top.
|
||||
A tall tree stands in front of the building, and a window on the building is
|
||||
visible from the water. The water is green, and the sky is blue.
|
||||
```
|
||||
|
||||
### Migrating to single-file format
|
||||
|
||||
There is now a new format for the weights file, which is a single file that
|
||||
allows to contain the tokenizer (and the model type) directly. A tool to migrate
|
||||
from the multi-file format to the single-file format is available.
|
||||
|
||||
```sh
|
||||
io/migrate_weights \
|
||||
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
|
||||
--output_weights .../gemma2-2b-it-sfp-single.sbs
|
||||
```
|
||||
|
||||
After migration, you can omit the tokenizer argument like this:
|
||||
|
||||
```sh
|
||||
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
|
||||
```
|
||||
|
||||
### Troubleshooting and FAQs
|
||||
|
||||
**Problems building in Windows / Visual Studio**
|
||||
|
||||
Currently if you're using Windows, we recommend building in WSL (Windows
|
||||
Subsystem for Linux). We are exploring options to enable other build
|
||||
configurations, see issues for active discussion.
|
||||
|
||||
**Model does not respond to instructions and produces strange output**
|
||||
|
||||
A common issue is that you are using a pre-trained model, which is not
|
||||
instruction-tuned and thus does not respond to instructions. Make sure you are
|
||||
using an instruction-tuned model (`gemma2-2b-it-sfp`) and not a pre-trained
|
||||
model (any model with a `-pt` suffix).
|
||||
|
||||
**What sequence lengths are supported?**
|
||||
|
||||
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
|
||||
models larger than 1B, this is typically 32K but 128K would also work given
|
||||
enough RAM. Note that long sequences will be slow due to the quadratic cost of
|
||||
attention.
|
||||
|
||||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||
|
||||
For PaliGemma 2 checkpoints, you can use python/convert_from_safetensors.py to
|
||||
convert from safetensors format (tested with building via bazel). For an adapter
|
||||
model, you will likely need to call merge_and_unload() to convert the adapter
|
||||
model to a single-file format before converting it.
|
||||
|
||||
Here is how to use it using a bazel build of the compression library assuming
|
||||
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
|
||||
|
||||
```sh
|
||||
bazel build //compression/python:compression
|
||||
BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
|
||||
python3 -c "import site; print(site.getsitepackages())"
|
||||
# Use your sites-packages file here:
|
||||
ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
|
||||
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
|
||||
```
|
||||
|
||||
**What are some easy ways to make the model run faster?**
|
||||
|
||||
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
|
||||
These are half the size of bf16 and thus use less memory bandwidth and cache
|
||||
space.
|
||||
2. Due to auto-tuning, the second and especially third query will be faster.
|
||||
3. If you're on a laptop, make sure power mode is set to maximize performance
|
||||
and saving mode is **off**. For most laptops, the power saving modes get
|
||||
activated automatically if the computer is not plugged in.
|
||||
4. Close other unused cpu-intensive applications.
|
||||
5. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
|
||||
cores get engaged.
|
||||
|
||||
We're also working on algorithmic and optimization approaches for faster
|
||||
inference, stay tuned.
|
||||
|
||||
## Usage
|
||||
|
||||
`gemma` has different usage modes, controlled by the verbosity flag.
|
||||
|
||||
All usage modes are currently interactive, triggering text generation upon
|
||||
newline input.
|
||||
|
||||
| Verbosity | Usage mode | Details |
|
||||
| --------------- | ---------- | --------------------------------------------- |
|
||||
| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. |
|
||||
| `--verbosity 1` | Default | Standard user-facing terminal UI. |
|
||||
| `--verbosity 2` | Detailed | Shows additional developer and debug info. |
|
||||
|
||||
### Interactive Terminal App
|
||||
|
||||
By default, verbosity is set to 1, bringing up a terminal-based interactive
|
||||
interface when `gemma` is invoked:
|
||||
|
||||
```sh
|
||||
$ ./gemma [...]
|
||||
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
|
||||
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
|
||||
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
|
||||
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
|
||||
__/ | | | | |
|
||||
|___/ |_| |_|
|
||||
|
||||
...
|
||||
|
||||
*Usage*
|
||||
Enter an instruction and press enter (%C reset conversation, %Q quits).
|
||||
|
||||
*Examples*
|
||||
- Write an email to grandma thanking her for the cookies.
|
||||
- What are some historical attractions to visit around Massachusetts?
|
||||
- Compute the nth fibonacci number in javascript.
|
||||
- Write a standup comedy bit about WebGPU programming.
|
||||
|
||||
> What are some outdoorsy places to visit around Boston?
|
||||
|
||||
[ Reading prompt ] .....................
|
||||
|
||||
|
||||
**Boston Harbor and Islands:**
|
||||
|
||||
* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
|
||||
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
|
||||
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
|
||||
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.
|
||||
|
||||
**Forest and Nature:**
|
||||
|
||||
* **Forest Park:** Hike through a scenic forest with diverse wildlife.
|
||||
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
|
||||
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### Usage as a Command Line Tool
|
||||
|
||||
For using the `gemma` executable as a command line tool, it may be useful to
|
||||
create an alias for gemma.cpp with arguments fully specified:
|
||||
|
||||
```sh
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --verbosity 0"
|
||||
```
|
||||
|
||||
Replace the above paths with your own paths to the model and tokenizer paths
|
||||
from the download.
|
||||
|
||||
Here is an example of prompting `gemma` with a truncated input
|
||||
file (using a `gemma2b` alias like defined above):
|
||||
|
||||
```sh
|
||||
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> CLI usage of gemma.cpp is experimental and should take context length
|
||||
> limitations into account.
|
||||
|
||||
The output of the above command should look like:
|
||||
|
||||
```sh
|
||||
[ Reading prompt ] [...]
|
||||
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
|
||||
|
||||
Let's break down the code:
|
||||
[...]
|
||||
```
|
||||
|
||||
### Incorporating gemma.cpp as a Library in your Project
|
||||
|
||||
The easiest way to incorporate gemma.cpp in your own project is to pull in
|
||||
gemma.cpp and dependencies using `FetchContent`. You can add the following to
|
||||
your CMakeLists.txt:
|
||||
|
||||
```
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||
FetchContent_MakeAvailable(gemma)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
```
|
||||
|
||||
Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific
|
||||
commit hash if you would like to pin the library version.
|
||||
|
||||
After your executable is defined (substitute your executable name for
|
||||
`[Executable Name]` below):
|
||||
|
||||
```
|
||||
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
|
||||
FetchContent_GetProperties(gemma)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
|
||||
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
```
|
||||
|
||||
### Building gemma.cpp as a Library
|
||||
|
||||
gemma.cpp can also be used as a library dependency in your own project. The
|
||||
shared library artifact can be built by modifying the make invocation to build
|
||||
the `libgemma` target instead of `gemma`.
|
||||
|
||||
> [!NOTE]
|
||||
> If you are using gemma.cpp in your own project with the `FetchContent` steps
|
||||
> in the previous section, building the library is done automatically by `cmake`
|
||||
> and this section can be skipped.
|
||||
|
||||
First, run `cmake`:
|
||||
|
||||
```sh
|
||||
cmake -B build
|
||||
```
|
||||
|
||||
Then, run `make` with the `libgemma` target:
|
||||
|
||||
```sh
|
||||
cd build
|
||||
make -j [number of parallel threads to use] libgemma
|
||||
```
|
||||
|
||||
If this is successful, you should now have a `libgemma` library file in the
|
||||
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
|
||||
|
||||
## Independent Projects Using gemma.cpp
|
||||
|
||||
Some independent projects using gemma.cpp:
|
||||
|
||||
- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
|
||||
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
|
||||
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)
|
||||
|
||||
If you would like to have your project included, feel free to get in touch or
|
||||
submit a PR with a `README.md` edit.
|
||||
|
||||
## Acknowledgements and Contacts
|
||||
|
||||
gemma.cpp was started in fall 2023 by
|
||||
[Austin Huang](mailto:austinvhuang@google.com) and
|
||||
[Jan Wassenberg](mailto:janwas@google.com), and subsequently released February
|
||||
2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
|
||||
|
||||
Griffin support was implemented in April 2024 thanks to contributions by Andrey
|
||||
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
|
||||
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
|
||||
Fischbacher and Zoltan Szabadka. It was removed in 2025-09.
|
||||
|
||||
Gemma-2 support was implemented in June/July 2024 with the help of several
|
||||
people.
|
||||
|
||||
PaliGemma support was implemented in September 2024 with contributions from
|
||||
Daniel Keysers.
|
||||
|
||||
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
|
||||
improvements, including major gains in efficiency, since the initial release.
|
||||
|
||||
This is not an officially supported Google product.
|
||||
@@ -0,0 +1,7 @@
|
||||
# Examples
|
||||
|
||||
In this directory are some simple examples illustrating usage of `gemma.cpp` as
|
||||
a library beyond the interactive `gemma` app implemented in `run.cc`.
|
||||
|
||||
- `hello_world/` - minimal/template project for using `gemma.cpp` as a library.
|
||||
It sets up the model state and generates text for a single hard coded prompt.
|
||||
@@ -0,0 +1,186 @@
|
||||
# Gemma in PyTorch
|
||||
|
||||
**Gemma** is a family of lightweight, state-of-the art open models built from research and technology used to create Google Gemini models. They include both text-only and multimodal decoder-only large language models, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:
|
||||
|
||||
* [Gemma on Google AI](https://ai.google.dev/gemma)
|
||||
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma-3)
|
||||
* [Gemma on Vertex AI Model Garden](https://pantheon.corp.google.com/vertex-ai/publishers/google/model-garden/gemma3)
|
||||
|
||||
This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.
|
||||
|
||||
## Updates
|
||||
|
||||
* [March 12th, 2025 🔥] Support Gemma v3. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-3/pytorch) and [Hugging Face](https://huggingface.co/models?other=gemma_torch)
|
||||
|
||||
* [June 26th, 2024] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
|
||||
|
||||
* [April 9th, 2024] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
|
||||
|
||||
* [April 5, 2024] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
|
||||
|
||||
## Download Gemma model checkpoint
|
||||
|
||||
You can find the model checkpoints on Kaggle:
|
||||
|
||||
- [Gemma 3](https://www.kaggle.com/models/google/gemma-3/pyTorch)
|
||||
- [Gemma 2](https://www.kaggle.com/models/google/gemma-2/pyTorch)
|
||||
- [Gemma](https://www.kaggle.com/models/google/gemma/pyTorch)
|
||||
|
||||
Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go the the model repository of the model of interest and click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub` installed, you can also run:
|
||||
|
||||
```
|
||||
huggingface-cli download google/gemma-3-4b-it-pytorch
|
||||
```
|
||||
|
||||
The following model sizes are available:
|
||||
|
||||
- **Gemma 3**:
|
||||
- **Text only**: 1b
|
||||
- **Multimodal**: 4b, 12b, 27b_v3
|
||||
- **Gemma 2**:
|
||||
- **Text only**: 2b-v2, 9b, 27b
|
||||
- **Gemma**:
|
||||
- **Text only**: 2b, 7b
|
||||
|
||||
|
||||
Note that you can choose between the 1B, 4B, 12B, and 27B variants.
|
||||
|
||||
```
|
||||
VARIANT=<1b, 2b, 2b-v2, 4b, 7b, 9b, 12b, 27b, 27b_v3>
|
||||
CKPT_PATH=<Insert ckpt path here>
|
||||
```
|
||||
|
||||
## Try it free on Colab
|
||||
|
||||
Follow the steps at
|
||||
[https://ai.google.dev/gemma/docs/pytorch_gemma](https://ai.google.dev/gemma/docs/pytorch_gemma).
|
||||
|
||||
## Try it out with PyTorch
|
||||
|
||||
Prerequisite: make sure you have setup docker permission properly as a non-root user.
|
||||
|
||||
```bash
|
||||
sudo usermod -aG docker $USER
|
||||
newgrp docker
|
||||
```
|
||||
|
||||
### Build the docker image.
|
||||
|
||||
```bash
|
||||
DOCKER_URI=gemma:${USER}
|
||||
|
||||
docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}
|
||||
```
|
||||
|
||||
### Run Gemma inference on CPU.
|
||||
|
||||
> NOTE: This is a multimodal example. Use a multimodal variant.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_multimodal.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Run Gemma inference on GPU.
|
||||
|
||||
> NOTE: This is a multimodal example. Use a multimodal variant.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
--gpus all \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_multimodal.py \
|
||||
--device=cuda \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}"
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
## Try It out with PyTorch/XLA
|
||||
|
||||
### Build the docker image (CPU, TPU).
|
||||
|
||||
```bash
|
||||
DOCKER_URI=gemma_xla:${USER}
|
||||
|
||||
docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
|
||||
```
|
||||
|
||||
### Build the docker image (GPU).
|
||||
|
||||
```bash
|
||||
DOCKER_URI=gemma_xla_gpu:${USER}
|
||||
|
||||
docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}
|
||||
```
|
||||
|
||||
### Run Gemma inference on CPU.
|
||||
|
||||
> NOTE: This is a multimodal example. Use a multimodal variant.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
--shm-size 4gb \
|
||||
-e PJRT_DEVICE=CPU \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_xla.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Run Gemma inference on TPU.
|
||||
|
||||
Note: be sure to use the docker container built from `xla.Dockerfile`.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
--shm-size 4gb \
|
||||
-e PJRT_DEVICE=TPU \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_xla.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Run Gemma inference on GPU.
|
||||
|
||||
Note: be sure to use the docker container built from `xla_gpu.Dockerfile`.
|
||||
|
||||
```bash
|
||||
docker run -t --rm --privileged \
|
||||
--shm-size=16g --net=host --gpus all \
|
||||
-e USE_CUDA=1 \
|
||||
-e PJRT_DEVICE=CUDA \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_xla.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Tokenizer Notes
|
||||
|
||||
99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. Unused tokens are in the string format of `<unused[0-97]>` with token id range of `[7-104]`.
|
||||
|
||||
```
|
||||
"<unused0>": 7,
|
||||
"<unused1>": 8,
|
||||
"<unused2>": 9,
|
||||
...
|
||||
"<unused98>": 104,
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This is not an officially supported Google product.
|
||||
@@ -0,0 +1,107 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import contextlib
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from absl import app, flags
|
||||
|
||||
from gemma import config
|
||||
from gemma import model as gemma_model
|
||||
|
||||
# Define flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True)
|
||||
flags.DEFINE_string('variant', '4b', 'Model variant.')
|
||||
flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
|
||||
flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.')
|
||||
flags.DEFINE_integer('seed', 12345, 'Random seed.')
|
||||
flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
|
||||
flags.DEFINE_string('prompt', 'What are large language models?', 'Input prompt for the model.')
|
||||
|
||||
# Define valid text only model variants
|
||||
_VALID_MODEL_VARIANTS = ['2b', '2b-v2', '7b', '9b', '27b', '1b']
|
||||
|
||||
# Define valid devices
|
||||
_VALID_DEVICES = ['cpu', 'cuda']
|
||||
|
||||
# Validator function for the 'variant' flag
|
||||
def validate_variant(variant):
|
||||
if variant not in _VALID_MODEL_VARIANTS:
|
||||
raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}')
|
||||
return True
|
||||
|
||||
# Validator function for the 'device' flag
|
||||
def validate_device(device):
|
||||
if device not in _VALID_DEVICES:
|
||||
raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}')
|
||||
return True
|
||||
|
||||
# Register the validator for the 'variant' flag
|
||||
flags.register_validator('variant', validate_variant, message='Invalid model variant.')
|
||||
|
||||
# Register the validator for the 'device' flag
|
||||
flags.register_validator('device', validate_device, message='Invalid device.')
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_tensor_type(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
def main(_):
|
||||
# Construct the model config.
|
||||
model_config = config.get_model_config(FLAGS.variant)
|
||||
model_config.dtype = "float32"
|
||||
model_config.quant = FLAGS.quant
|
||||
|
||||
# Seed random.
|
||||
random.seed(FLAGS.seed)
|
||||
np.random.seed(FLAGS.seed)
|
||||
torch.manual_seed(FLAGS.seed)
|
||||
|
||||
# Create the model and load the weights.
|
||||
device = torch.device(FLAGS.device)
|
||||
with _set_default_tensor_type(model_config.get_dtype()):
|
||||
model = gemma_model.GemmaForCausalLM(model_config)
|
||||
model.load_weights(FLAGS.ckpt)
|
||||
model = model.to(device).eval()
|
||||
print("Model loading done")
|
||||
|
||||
# Generate the response.
|
||||
result = model.generate(FLAGS.prompt, device, output_len=FLAGS.output_len)
|
||||
|
||||
# Print the prompts and results.
|
||||
print('======================================')
|
||||
print(f'PROMPT: {FLAGS.prompt}')
|
||||
print(f'RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
||||
|
||||
# How to run this script:
|
||||
|
||||
# Example command (replace with your actual paths and values):
|
||||
# python scripts/run.py --device=cpu --ckpt=/path/to/your/pytorch_checkpoint/model.ckpt --output_len=2 --prompt="The name of the capital of Italy is"
|
||||
# Important:
|
||||
# - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file.
|
||||
# - Choose the correct --variant (model size).
|
||||
# - Use --device=cuda if you have a GPU; otherwise, use --device=cpu.
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import contextlib
|
||||
import random
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
from gemma import config
|
||||
from gemma import gemma3_model
|
||||
|
||||
# Define flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_CKPT = flags.DEFINE_string(
|
||||
'ckpt', None, 'Path to the checkpoint file.', required=True
|
||||
)
|
||||
_VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.')
|
||||
_DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
|
||||
_OUTPUT_LEN = flags.DEFINE_integer(
|
||||
'output_len', 10, 'Length of the output sequence.'
|
||||
)
|
||||
_SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.')
|
||||
_QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
|
||||
|
||||
# Define valid multimodal model variants
|
||||
_VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3']
|
||||
|
||||
# Define valid devices
|
||||
_VALID_DEVICES = ['cpu', 'cuda']
|
||||
|
||||
|
||||
# Validator function for the 'variant' flag
|
||||
def validate_variant(variant):
|
||||
if variant not in _VALID_MODEL_VARIANTS:
|
||||
raise ValueError(
|
||||
f'Invalid variant: {variant}. Valid variants are:'
|
||||
f' {_VALID_MODEL_VARIANTS}'
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# Validator function for the 'device' flag
|
||||
def validate_device(device):
|
||||
if device not in _VALID_DEVICES:
|
||||
raise ValueError(
|
||||
f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}'
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# Register the validator for the 'variant' flag
|
||||
flags.register_validator(
|
||||
'variant', validate_variant, message='Invalid model variant.'
|
||||
)
|
||||
|
||||
# Register the validator for the 'device' flag
|
||||
flags.register_validator('device', validate_device, message='Invalid device.')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_tensor_type(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
|
||||
def main(_):
|
||||
# Construct the model config.
|
||||
model_config = config.get_model_config(_VARIANT.value)
|
||||
model_config.dtype = 'float32'
|
||||
model_config.quant = _QUANT.value
|
||||
image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
|
||||
"lilly": "scripts/images/lilly.jpg",
|
||||
"sunflower": "scripts/images/sunflower.JPG",
|
||||
'golden_test_image': (
|
||||
'scripts/images/test_image.jpg'
|
||||
),
|
||||
}
|
||||
|
||||
image = {}
|
||||
for key in image_paths:
|
||||
try:
|
||||
image[key] = Image.open(image_paths[key]) # Open local file
|
||||
image[key].show()
|
||||
except IOError as e:
|
||||
print(f"Error loading image: {e}")
|
||||
exit()
|
||||
|
||||
# Seed random.
|
||||
random.seed(_SEED.value)
|
||||
np.random.seed(_SEED.value)
|
||||
torch.manual_seed(_SEED.value)
|
||||
|
||||
# Create the model and load the weights.
|
||||
device = torch.device(_DEVICE.value)
|
||||
with _set_default_tensor_type(model_config.get_dtype()):
|
||||
model = gemma3_model.Gemma3ForMultimodalLM(model_config)
|
||||
model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
|
||||
# model.load_weights(_CKPT.value)
|
||||
model = model.to(device).eval()
|
||||
print('Model loading done')
|
||||
|
||||
# Generate text only.
|
||||
result = model.generate(
|
||||
[
|
||||
[
|
||||
'<start_of_turn>user The capital of Italy'
|
||||
' is?<end_of_turn>\n<start_of_turn>model'
|
||||
],
|
||||
[
|
||||
'<start_of_turn>user What is your'
|
||||
' purpose?<end_of_turn>\n<start_of_turn>model'
|
||||
],
|
||||
],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the results.
|
||||
print('======================================')
|
||||
print(f'Text only RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
# Generate golden Gemax test image.
|
||||
result = model.generate(
|
||||
[[
|
||||
'<start_of_turn>user\n',
|
||||
image['golden_test_image'],
|
||||
'Caption this image. <end_of_turn>\n<start_of_turn>model',
|
||||
]],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the result.
|
||||
print('======================================')
|
||||
print(f'Golden test image RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
# Generate text and image.
|
||||
result = model.generate(
|
||||
[[
|
||||
'<start_of_turn>user\n',
|
||||
image['cow_in_beach'],
|
||||
(
|
||||
'The name of the animal in the image is'
|
||||
' <end_of_turn>\n<start_of_turn>model'
|
||||
),
|
||||
]],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the result.
|
||||
print('======================================')
|
||||
print(f'Single image RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
# Generate interleave text and multiple images.
|
||||
result = model.generate(
|
||||
[[
|
||||
'<start_of_turn>user\nThis image',
|
||||
image['lilly'],
|
||||
'and this image',
|
||||
image['sunflower'],
|
||||
'are similar because? <end_of_turn>\n<start_of_turn>model',
|
||||
]],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the result.
|
||||
print('======================================')
|
||||
print(f'Interleave images RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
|
||||
from gemma.config import GemmaConfig, get_model_config
|
||||
from gemma.model_xla import GemmaForCausalLM
|
||||
from gemma.tokenizer import Tokenizer
|
||||
import gemma.xla_model_parallel as xla_model_parallel
|
||||
|
||||
USE_CUDA = os.environ.get('USE_CUDA', False)
|
||||
if not USE_CUDA:
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
else:
|
||||
# Choose an available port.
|
||||
with contextlib.closing(socket.socket(socket.AF_INET,
|
||||
socket.SOCK_STREAM)) as s:
|
||||
s.bind(('', 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
MASTER_PORT = str(s.getsockname()[1])
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_tensor_type(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
|
||||
def generate(
|
||||
i: int,
|
||||
model_config: GemmaConfig,
|
||||
ckpt_path: str,
|
||||
prompts: List[str],
|
||||
output_lens: List[int],
|
||||
temperatures: Union[List[float], None],
|
||||
top_ps: List[float],
|
||||
top_ks: List[int],
|
||||
seed: int
|
||||
):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if USE_CUDA:
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = MASTER_PORT
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
"nccl",
|
||||
rank=int(os.environ.get("RANK", 0)),
|
||||
world_size=int(os.environ.get("WORLD_SIZE", 1)))
|
||||
xla_model_parallel.set_g_group()
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
else:
|
||||
device = xm.xla_device()
|
||||
xm.set_rng_state(seed, device)
|
||||
|
||||
rank = xla_model_parallel.get_model_parallel_rank()
|
||||
world_size = xla_model_parallel.get_model_parallel_world_size()
|
||||
if rank > 0:
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
|
||||
# build, load and compile model.
|
||||
with _set_default_tensor_type(model_config.get_dtype()):
|
||||
model = GemmaForCausalLM(model_config, world_size, rank, device)
|
||||
model.load_weights(ckpt_path)
|
||||
model = model.to(device).eval()
|
||||
|
||||
# create tokenizer.
|
||||
tokenizer = Tokenizer(model_config.tokenizer)
|
||||
|
||||
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
|
||||
min_prompt_len = min(len(p) for p in prompt_tokens)
|
||||
|
||||
batch_size = len(prompts)
|
||||
if temperatures is not None:
|
||||
assert batch_size == len(temperatures)
|
||||
assert batch_size == len(top_ps)
|
||||
assert batch_size == len(top_ks)
|
||||
max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
|
||||
assert max_seq_len <= model_config.max_position_embeddings
|
||||
if model_config.num_key_value_heads < world_size:
|
||||
assert world_size % model_config.num_key_value_heads == 0
|
||||
n_local_heads = 1
|
||||
else:
|
||||
assert model_config.num_key_value_heads % world_size == 0
|
||||
n_local_heads = model_config.num_key_value_heads // world_size
|
||||
|
||||
# build KV caches
|
||||
kv_caches = []
|
||||
for _ in range(model_config.num_hidden_layers):
|
||||
k_cache = torch.zeros(
|
||||
size=(batch_size, max_seq_len, n_local_heads,
|
||||
model_config.head_dim),
|
||||
dtype=model_config.get_dtype(),
|
||||
device=device,
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
size=(batch_size, max_seq_len, n_local_heads,
|
||||
model_config.head_dim),
|
||||
dtype=model_config.get_dtype(),
|
||||
device=device,
|
||||
)
|
||||
kv_caches.append((k_cache, v_cache))
|
||||
|
||||
# prepare inputs
|
||||
token_ids_tensor = torch.full((batch_size, max_seq_len),
|
||||
tokenizer.pad_id,
|
||||
dtype=torch.int64)
|
||||
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
|
||||
tokenizer.pad_id,
|
||||
dtype=torch.int64)
|
||||
for i, p in enumerate(prompt_tokens):
|
||||
token_ids_tensor[i, :len(p)] = torch.tensor(p)
|
||||
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
|
||||
p[:min_prompt_len])
|
||||
token_ids_tensor = token_ids_tensor.to(device)
|
||||
prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
|
||||
input_token_ids_tensor = input_token_ids_tensor.to(device)
|
||||
input_positions_tensor = torch.arange(0, min_prompt_len,
|
||||
dtype=torch.int64).to(device)
|
||||
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
|
||||
-2.3819763e38).to(torch.float)
|
||||
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
|
||||
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
|
||||
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
|
||||
temperatures_tensor = None if not temperatures else torch.FloatTensor(temperatures).to(device)
|
||||
top_ps_tensor = torch.FloatTensor(top_ps).to(device)
|
||||
top_ks_tensor = torch.LongTensor(top_ks).to(device)
|
||||
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
|
||||
if not USE_CUDA:
|
||||
xm.mark_step()
|
||||
|
||||
# Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
|
||||
for i in range(max_seq_len - min_prompt_len):
|
||||
next_token_ids, _ = model(
|
||||
input_token_ids=input_token_ids_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
kv_write_indices=None,
|
||||
kv_caches=kv_caches,
|
||||
mask=curr_mask_tensor,
|
||||
output_positions=output_positions_tensor,
|
||||
temperatures=temperatures_tensor,
|
||||
top_ps=top_ps_tensor,
|
||||
top_ks=top_ks_tensor,
|
||||
)
|
||||
curr_prompt_mask = prompt_mask_tensor.index_select(
|
||||
1, output_index).squeeze(dim=1)
|
||||
curr_token_ids = token_ids_tensor.index_select(
|
||||
1, output_index).squeeze(dim=1)
|
||||
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
|
||||
next_token_ids).unsqueeze(dim=1)
|
||||
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
|
||||
|
||||
input_token_ids_tensor = output_token_ids
|
||||
input_positions_tensor = output_index.unsqueeze(dim=-1)
|
||||
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
|
||||
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
|
||||
output_index = output_index + 1
|
||||
if not USE_CUDA:
|
||||
xm.mark_step()
|
||||
|
||||
# Detokenization.
|
||||
token_ids = token_ids_tensor.tolist()
|
||||
results = []
|
||||
for i, tokens in enumerate(token_ids):
|
||||
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
|
||||
output_lens[i]]
|
||||
if tokenizer.eos_id in trimmed_output:
|
||||
eos_index = trimmed_output.index(tokenizer.eos_id)
|
||||
trimmed_output = trimmed_output[:eos_index]
|
||||
results.append(tokenizer.decode(trimmed_output))
|
||||
|
||||
for prompt, result in zip(prompts, results):
|
||||
print('======================================')
|
||||
print(f'PROMPT: {prompt}')
|
||||
print(f'RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
|
||||
def main(args):
|
||||
model_config = get_model_config(args.variant)
|
||||
model_config.quant = args.quant
|
||||
|
||||
prompts = [args.prompt]
|
||||
n = len(prompts)
|
||||
output_lengths = [args.output_len] * n
|
||||
temperatures = [0.95] * n
|
||||
top_ps = [1.0] * n
|
||||
top_ks = [100] * n
|
||||
|
||||
if USE_CUDA:
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = MASTER_PORT
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
"nccl",
|
||||
rank=int(os.environ.get("RANK", 0)),
|
||||
world_size=int(os.environ.get("WORLD_SIZE", 1)))
|
||||
xla_model_parallel.set_g_group()
|
||||
torch.multiprocessing.spawn(
|
||||
generate,
|
||||
args=(
|
||||
model_config,
|
||||
args.ckpt,
|
||||
prompts,
|
||||
output_lengths,
|
||||
temperatures,
|
||||
top_ps,
|
||||
top_ks,
|
||||
args.seed,
|
||||
),
|
||||
)
|
||||
else:
|
||||
xmp.spawn(
|
||||
generate,
|
||||
args=(
|
||||
model_config,
|
||||
args.ckpt,
|
||||
prompts,
|
||||
output_lengths,
|
||||
temperatures,
|
||||
top_ps,
|
||||
top_ks,
|
||||
args.seed,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt", type=str, required=True)
|
||||
parser.add_argument("--variant",
|
||||
type=str,
|
||||
default="2b",
|
||||
choices=["2b", "2b-v2", "7b", "9b", "27b"])
|
||||
parser.add_argument("--output_len", type=int, default=4)
|
||||
parser.add_argument("--seed", type=int, default=12345)
|
||||
parser.add_argument("--quant", action='store_true')
|
||||
parser.add_argument("--prompt", type=str, default="The meaning of life is")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user