docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,161 @@
|
||||
# Gemma 4 — Hugging Face Canonical Tooling
|
||||
|
||||
Downloaded April 2026. First-party Google/HF content only. No weights, no third-party fine-tunes.
|
||||
|
||||
## What's here
|
||||
|
||||
### `model-cards/`
|
||||
Verbatim `README.md` from every `google/gemma-4-*` repo (raw endpoint, ungated). Plus the chat template and tokenizer config for two representative variants (31B-it and E4B-it). All eight model cards have identical body text; they differ only in the `pipeline_tag:` YAML frontmatter and size-specific tables.
|
||||
|
||||
| File | What it demonstrates |
|
||||
|------|----------------------|
|
||||
| `gemma-4-31B-it-README.md` | Flagship dense (33B) instruction-tuned. Full "how to use" from Google+HF. |
|
||||
| `gemma-4-31B-README.md` | Base (pretrained) variant of the above. |
|
||||
| `gemma-4-26B-A4B-it-README.md` | MoE (26B params, 4B active) instruction-tuned. The "A4B" = 4B active. |
|
||||
| `gemma-4-26B-A4B-README.md` | Base MoE. |
|
||||
| `gemma-4-E4B-it-README.md` | Edge-sized 8B instruction-tuned. Multimodal including audio. |
|
||||
| `gemma-4-E4B-README.md` | Base E4B. |
|
||||
| `gemma-4-E2B-it-README.md` | Smallest (5B) instruction-tuned, mobile-targeted. |
|
||||
| `gemma-4-E2B-README.md` | Base E2B. |
|
||||
| `gemma-4-31B-it-chat_template.jinja` | **Canonical chat template.** 16KB Jinja — handles system/user/model/tool roles, thinking channel, tool calls, image/audio/video tokens. |
|
||||
| `gemma-4-E4B-it-chat_template.jinja` | Near-identical to 31B's (131-byte difference — likely one whitespace-sensitive thing around audio handling). |
|
||||
| `gemma-4-31B-it-tokenizer_config.json` | **Special-token inventory + `response_schema` regex machinery.** See "New capabilities" below. |
|
||||
| `gemma-4-E4B-it-tokenizer_config.json` | Same shape. |
|
||||
|
||||
### `transformers/`
|
||||
Files under `src/transformers/models/gemma4/` on `huggingface/transformers@main`. Full files for small ones; outlines (signatures + first 12 lines per class/def) for the two large ones.
|
||||
|
||||
| File | Lines | What |
|
||||
|------|-------|------|
|
||||
| `__init__.py` | 33 | Module exports |
|
||||
| `configuration_gemma4.py` | 352 | `Gemma4Config`, `Gemma4TextConfig`, `Gemma4AudioConfig`, `Gemma4VisionConfig` — all hyperparams |
|
||||
| `processing_gemma4.py` | 366 | `Gemma4Processor` — the thing `AutoProcessor.from_pretrained` returns. Includes `parse_response()` |
|
||||
| `feature_extraction_gemma4.py` | 298 | Audio feature extraction (mel spec, padding) |
|
||||
| `image_processing_gemma4.py` | 220 | Tensor-backed image preprocessing |
|
||||
| `image_processing_pil_gemma4.py` | 278 | PIL-backed variant (slower fallback) |
|
||||
| `video_processing_gemma4.py` | 237 | Frame sampling + stitching to image tokens |
|
||||
| `modeling_gemma4-OUTLINE.py` | 723 | Outline of the 2657-line modeling file (43 classes: attention, MoE, audio encoder, vision tower, all LM heads) |
|
||||
| `modular_gemma4-OUTLINE.py` | 563 | Outline of the modular source file — shows Gemma4 **inherits from Gemma3n classes** (RMSNorm, attention blocks etc.) confirming the 3n→4 lineage |
|
||||
|
||||
Full files: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gemma4
|
||||
|
||||
### `recipes/`
|
||||
From `huggingface/huggingface-gemma-recipes` — the canonical HF recipe repo. The only Gemma 4-specific recipe as of April 2026 is one notebook; the rest is Gemma 3n which is architecturally the parent of Gemma 4.
|
||||
|
||||
| File | What |
|
||||
|------|------|
|
||||
| `notebooks/Gemma4_E2B-Multimodal.ipynb` | **The one first-party Gemma-4 recipe.** Original ipynb. 36 cells: image, video, audio, function calling, object detection with `box_2d`, any-to-any pipeline, captioning. |
|
||||
| `notebooks/Gemma4_E2B-Multimodal-extracted.py` | Same notebook flattened to readable .py for grep/diff. |
|
||||
| `scripts/ft_gemma3n_image_trl.py` | TRL SFT fine-tune of Gemma 3n on images. Direct precursor to Gemma 4 SFT. |
|
||||
| `scripts/ft_gemma3n_image_vt.py` | Vision+text fine-tune without TRL (pure Transformers Trainer). |
|
||||
| `scripts/ft_gemma3n_audio_vt.py` | Audio+text fine-tune. |
|
||||
| `scripts/gemma3n_fine_tuning_on_all_modalities.py` | All-modalities SFT script — template for full Gemma-4 all-modal SFT. |
|
||||
| `scripts/carla_vlm_gemma.py` | CARLA driving sim VLM example using Gemma. |
|
||||
|
||||
### `trl/`
|
||||
**Empty as of April 2026.** Searched `huggingface/trl/examples/scripts` — only `sft_gemma3.py` and `sft_vlm_gemma3.py` exist, no gemma4 yet. The gemma-recipes repo's `ft_gemma3n_image_trl.py` is the closest first-party TRL pattern; it is saved under `recipes/scripts/` above.
|
||||
|
||||
### `peft/`
|
||||
**Empty as of April 2026.** `huggingface/peft/examples` has no gemma-specific directory. The canonical HF PEFT guide for Gemma is the blog post `gemma-peft.md`, saved under `blog/` below. It covers Gemma 1 but the LoRA target-module patterns apply unchanged to Gemma 4 (same `q_proj/k_proj/v_proj/o_proj` naming).
|
||||
|
||||
### `blog/`
|
||||
| File | What |
|
||||
|------|------|
|
||||
| `gemma4-blog.md` | **"Welcome Gemma 4: Frontier multimodal intelligence on device"** — the HF launch blog. 764 lines. Authored by merve. Covers architecture, capabilities, transformers usage, HF Inference API, llama.cpp/MLX quantization, thinking mode examples. |
|
||||
| `gemma-peft-blog.md` | "Fine-Tuning Gemma Models in Hugging Face" — the PEFT/LoRA recipe blog (gemma-agnostic, target modules unchanged for Gemma 4). |
|
||||
|
||||
### `spaces/`
|
||||
The official HF-run interactive demo Spaces.
|
||||
|
||||
| File | What |
|
||||
|------|------|
|
||||
| `huggingface-projects_gemma-4-31b-it-app.py` | Official 31B demo (Gradio 6 chat + multimodal). |
|
||||
| `huggingface-projects_gemma-4-e4b-it-app.py` | Official E4B demo. **More illustrative** — shows the full multimodal+thinking pattern in ~320 lines. |
|
||||
| `*-requirements.txt` | Pinned deps. **`transformers==5.5.4`** (as of 2026-04-18) — that's the minimum version for Gemma 4 in transformers main line. |
|
||||
|
||||
---
|
||||
|
||||
## New capabilities the HF integration exposes that weren't in the existing corpus
|
||||
|
||||
1. **`AutoModelForMultimodalLM`** — new transformers AutoClass, not `AutoModelForCausalLM`. Required to get any-to-any routing (text+image+audio+video in, text out). The corpus's `CORPUS_capabilities.md` should note this.
|
||||
|
||||
2. **`processor.parse_response(text) -> dict`** — built into `Gemma4Processor`. Returns `{thinking, content, tool_calls}` parsed from raw decoded output. Driven by regexes declared in `tokenizer_config.json` under `response_schema` (new HF feature using `x-regex`, `x-regex-iterator`, and a custom `x-parser: gemma4-tool-call`). **You no longer need to hand-roll tool-call regex parsing** if you use the HF processor — this is the HF-canonical replacement for the manual parsing done in `CORPUS_tool_calling_format.md`.
|
||||
|
||||
3. **`enable_thinking=True`** — a kwarg to `processor.apply_chat_template()`. When set, injects `<|think|>` at the top of the system turn. **This is how you turn reasoning mode on** through the HF API. Not documented in the existing corpus.
|
||||
|
||||
4. **`load_audio_from_video=True`** — another `apply_chat_template` kwarg. Pulls the audio track out of a video URL and feeds it as audio tokens alongside sampled frames. Only relevant for E2B/E4B which have audio; the notebook comment explicitly calls this out.
|
||||
|
||||
5. **`pipeline("any-to-any", model=...)`** — a new HF pipeline task registered for Gemma 4. Accepts the chat-style messages list directly. Easiest one-liner for multimodal inference.
|
||||
|
||||
6. **Object detection via `box_2d` JSON** — prompting with "What's the bounding box for the X?" returns `[{"box_2d": [ymin, xmin, ymax, xmax], "label": "..."}]` in a 1000x1000 normalized coordinate frame, with images resized to multiples of 48 pixels. This is a Gemma-4-specific convention the notebook demonstrates. Corpus doesn't cover this.
|
||||
|
||||
7. **Thinking delimiters are `<|channel>thought...<channel|>`** — not `<thinking>...</thinking>` like some other open-weights models. The Space app explicitly strips these to pass to Gradio 6's `reasoning_tags` for collapsible thinking UI.
|
||||
|
||||
8. **Breaking change in role/turn markers vs Gemma 3** — Gemma 3 used `<start_of_turn>user ... <end_of_turn>`. Gemma 4 uses `<|turn>user\n ... <turn|>`. Tokenizer config:
|
||||
- `sot_token`: `<|turn>` (start of turn)
|
||||
- `eot_token`: `<turn|>` (end of turn)
|
||||
- Role after `<|turn>` can be `system`, `user`, `model`, or `tool`.
|
||||
- `enable_thinking` injects a `<|think|>` marker into the first system turn.
|
||||
Anything in the homelab that hard-codes `<start_of_turn>` for Gemma needs to branch on family version. Worth adding to `GOTCHAS.md`.
|
||||
|
||||
---
|
||||
|
||||
## Canonical chat template format
|
||||
|
||||
**Source of truth:** the two `.jinja` files in `model-cards/`. Use them directly — **do not reimplement.** The tokenizer loads them automatically:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
processor = AutoProcessor.from_pretrained("google/gemma-4-E4B-it")
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tools=[WEATHER_TOOL], # optional; OpenAI-style tool schema
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True, # turns on reasoning, injects <|think|>
|
||||
load_audio_from_video=False, # only for video inputs
|
||||
)
|
||||
output = model.generate(**inputs, max_new_tokens=1000)
|
||||
generated = processor.decode(output[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
|
||||
result = processor.parse_response(generated)
|
||||
# → {"thinking": "...", "content": "...", "tool_calls": [...]}
|
||||
```
|
||||
|
||||
### Wire format that the template produces
|
||||
|
||||
```
|
||||
<bos><|turn>system
|
||||
<|think|>
|
||||
{system prompt here if any}
|
||||
<|tool>declaration:get_weather{city:{type:<|"|>STRING<|"|>,description:<|"|>...<|"|>}}<tool|>
|
||||
<turn|>
|
||||
<|turn>user
|
||||
{user text}
|
||||
<|image|> ← placeholder for each image
|
||||
<|audio|> ← placeholder for each audio
|
||||
<|video|> ← placeholder for each video
|
||||
<turn|>
|
||||
<|turn>model
|
||||
<|channel>thought
|
||||
{reasoning text}
|
||||
<channel|>
|
||||
<|tool_call>call:get_weather{city:<|"|>London<|"|>}<tool_call|>
|
||||
<|tool_response>response:get_weather{temperature:15}<tool_response|>
|
||||
{final content}
|
||||
<turn|>
|
||||
```
|
||||
|
||||
Every Gemma-4-specific token appears in `tokenizer_config.json`. The `apply_chat_template` call + the `response_schema` + `parse_response()` round-trip means **homelab code should never hand-emit these tokens** — always go through the processor.
|
||||
|
||||
---
|
||||
|
||||
## Source URLs (first-party only)
|
||||
|
||||
- Model collection: https://huggingface.co/collections/google/gemma-4
|
||||
- transformers gemma4 dir: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gemma4
|
||||
- Recipes repo: https://github.com/huggingface/huggingface-gemma-recipes
|
||||
- Launch blog: https://huggingface.co/blog/gemma4
|
||||
- Official 31B Space: https://huggingface.co/spaces/huggingface-projects/gemma-4-31b-it
|
||||
- Official E4B Space: https://huggingface.co/spaces/huggingface-projects/gemma-4-e4b-it
|
||||
@@ -0,0 +1,207 @@
|
||||
---
|
||||
title: Fine-Tuning Gemma Models in Hugging Face
|
||||
thumbnail: /blog/assets/gemma-peft/thumbnail.png
|
||||
authors:
|
||||
- user: svaibhav
|
||||
guest: true
|
||||
- user: alanwaketan
|
||||
guest: true
|
||||
- user: ybelkada
|
||||
- user: ArthurZ
|
||||
---
|
||||
|
||||
# Fine-Tuning Gemma Models in Hugging Face
|
||||
|
||||
We recently announced that [Gemma](https://huggingface.co/blog/gemma), the open weights language model from Google Deepmind, is available for the broader open-source community via Hugging Face. It’s available in 2 billion and 7 billion parameter sizes with pretrained and instruction-tuned flavors. It’s available on Hugging Face, supported in TGI, and easily accessible for deployment and fine-tuning in the Vertex Model Garden and Google Kubernetes Engine.
|
||||
|
||||
<div class="flex items-center justify-center">
|
||||
<img src="/blog/assets/gemma-peft/Gemma-peft.png" alt="Gemma Deploy">
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
The Gemma family of models also happens to be well suited for prototyping and experimentation using the free GPU resource available via Colab. In this post we will briefly review how you can do [Parameter Efficient FineTuning (PEFT)](https://huggingface.co/blog/peft) for Gemma models, using the Hugging Face Transformers and PEFT libraries on GPUs and Cloud TPUs for anyone who wants to fine-tune Gemma models on their own dataset.
|
||||
|
||||
|
||||
|
||||
## Why PEFT?
|
||||
|
||||
The default (full weight) training for language models, even for modest sizes, tends to be memory and compute-intensive. On one hand, it can be prohibitive for users relying on openly available compute platforms for learning and experimentation, such as Colab or Kaggle. On the other hand, and even for enterprise users, the cost of adapting these models for different domains is an important metric to optimize. PEFT, or parameter-efficient fine tuning, is a popular technique to accomplish this at low cost.
|
||||
|
||||
## PyTorch on GPU and TPU
|
||||
|
||||
Gemma models in Hugging Face `transformers` are optimized for both PyTorch and PyTorch/XLA. This enables both TPU and GPU users to access and experiment with Gemma models as needed. Together with the Gemma release, we have also improved the [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/) experience for PyTorch/XLA in Hugging Face. This [FSDP via SPMD](https://github.com/pytorch/xla/issues/6379) integration also allows other Hugging Face models to take advantage of TPU acceleration via PyTorch/XLA. In this post, we will focus on PEFT, and more specifically on Low-Rank Adaptation (LoRA), for Gemma models. For a more comprehensive set of LoRA techniques, we encourage readers to review the [Scaling Down to Scale Up, from Lialin et al.](https://arxiv.org/pdf/2303.15647.pdf) and [this excellent post](https://pytorch.org/blog/finetune-llms/) post by Belkada et al.
|
||||
|
||||
## Low-Rank Adaptation for Large Language Models
|
||||
|
||||
Low-Rank Adaptation (LoRA) is one of the parameter-efficient fine-tuning techniques for large language models (LLMs). It addresses just a fraction of the total number of model parameters to be fine-tuned, by freezing the original model and only training adapter layers that are decomposed into low-rank matrices. The [PEFT library](https://github.com/huggingface/peft) provides an easy abstraction that allows users to select the model layers where adapter weights should be applied.
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
```
|
||||
|
||||
In this snippet, we refer to all `nn.Linear` layers as the target layers to be adapted.
|
||||
|
||||
In the following example, we will leverage [QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes), from [Dettmers et al.](https://arxiv.org/abs/2305.14314), in order to quantize the base model in 4-bit precision for a more memory efficient fine-tuning protocol. The model can be loaded with QLoRA by first installing the `bitsandbytes` library on your environment, and then passing a `BitsAndBytesConfig` object to `from_pretrained` when loading the model.
|
||||
|
||||
## Before we begin
|
||||
|
||||
In order to access Gemma model artifacts, users are required to accept [the consent form](https://huggingface.co/google/gemma-7b-it).
|
||||
Now let’s get started with the implementation.
|
||||
|
||||
## Learning to quote
|
||||
|
||||
Assuming that you have submitted the consent form, you can access the model artifacts from the [Hugging Face Hub](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
|
||||
|
||||
We start by downloading the model and the tokenizer. We also include a `BitsAndBytesConfig` for weight only quantization.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])
|
||||
```
|
||||
|
||||
Now we test the model before starting the finetuning, using a famous quote:
|
||||
|
||||
|
||||
```python
|
||||
text = "Quote: Imagination is more"
|
||||
device = "cuda:0"
|
||||
inputs = tokenizer(text, return_tensors="pt").to(device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=20)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
The model does a reasonable completion with some extra tokens:
|
||||
```
|
||||
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
|
||||
|
||||
-Albert Einstein
|
||||
|
||||
I
|
||||
```
|
||||
|
||||
But this is not exactly the format we would love the answer to be. Let’s see if we can use fine-tuning to teach the model to generate the answer in the following format.
|
||||
|
||||
```
|
||||
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
|
||||
|
||||
Author: Albert Einstein
|
||||
```
|
||||
|
||||
To begin with, let's select an English quotes dataset [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes).
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
data = load_dataset("Abirate/english_quotes")
|
||||
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
|
||||
```
|
||||
|
||||
Now let’s finetune this model using the LoRA config stated above:
|
||||
|
||||
```python
|
||||
import transformers
|
||||
from trl import SFTTrainer
|
||||
|
||||
def formatting_func(example):
|
||||
text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}<eos>"
|
||||
return [text]
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=data["train"],
|
||||
args=transformers.TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_steps=2,
|
||||
max_steps=10,
|
||||
learning_rate=2e-4,
|
||||
fp16=True,
|
||||
logging_steps=1,
|
||||
output_dir="outputs",
|
||||
optim="paged_adamw_8bit"
|
||||
),
|
||||
peft_config=lora_config,
|
||||
formatting_func=formatting_func,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Finally, we are ready to test the model once more with the same prompt we have used earlier:
|
||||
|
||||
```python
|
||||
text = "Quote: Imagination is"
|
||||
device = "cuda:0"
|
||||
inputs = tokenizer(text, return_tensors="pt").to(device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=20)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
This time we get the response in the format we like:
|
||||
|
||||
|
||||
```
|
||||
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
|
||||
Author: Albert Einstein
|
||||
```
|
||||
|
||||
|
||||
## Accelerate with FSDP via SPMD on TPU
|
||||
|
||||
As mentioned earlier, Hugging Face `transformers` now supports PyTorch/XLA’s latest FSDP implementation. This can greatly accelerate the fine-tuning speed. To enable that, one just needs to add a FSDP config to the `transformers.Trainer`:
|
||||
|
||||
```python
|
||||
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
|
||||
|
||||
# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
|
||||
fsdp_config = {
|
||||
"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
|
||||
"xla": True,
|
||||
"xla_fsdp_v2": True,
|
||||
"xla_fsdp_grad_ckpt": True
|
||||
}
|
||||
|
||||
# Finally, set up the trainer and train the model.
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=data,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=64, # This is actually the global batch size for SPMD.
|
||||
num_train_epochs=100,
|
||||
max_steps=-1,
|
||||
output_dir="./output",
|
||||
optim="adafactor",
|
||||
logging_steps=1,
|
||||
dataloader_drop_last = True, # Required for SPMD.
|
||||
fsdp="full_shard",
|
||||
fsdp_config=fsdp_config,
|
||||
),
|
||||
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
We walked through this simple example adapted from the source notebook to illustrate the LoRA finetuning method applied to Gemma models. The full colab for GPU can be found [here](https://huggingface.co/google/gemma-7b/blob/main/examples/notebook_sft_peft.ipynb), and the full script for TPU can be found [here](https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py). We are excited about the endless possibilities for research and learning thanks to this recent addition to our open source ecosystem. We encourage users to also visit the [Gemma documentation](https://huggingface.co/docs/transformers/v4.38.0/en/model_doc/gemma), as well as our [launch blog](https://huggingface.co/blog/gemma) for more examples to train, finetune and deploy Gemma models.
|
||||
|
||||
|
||||
@@ -0,0 +1,764 @@
|
||||
---
|
||||
title: "Welcome Gemma 4: Frontier multimodal intelligence on device"
|
||||
thumbnail: /blog/assets/gemma4/thumbnail.png
|
||||
authors:
|
||||
- user: merve
|
||||
- user: pcuenq
|
||||
- user: sergiopaniego
|
||||
- user: burtenshaw
|
||||
- user: Steveeeeeeen
|
||||
- user: alvarobartt
|
||||
- user: SaylorTwift
|
||||
---
|
||||
|
||||
# Welcome Gemma 4: Frontier multimodal intelligence on device
|
||||
|
||||
The Gemma 4 family of multimodal models by Google DeepMind is out on Hugging Face, with support for your favorite agents, inference engines, and fine-tuning libraries 🤗
|
||||
|
||||
These models are the real deal: truly open with Apache 2 licenses, high quality with pareto frontier arena scores, multimodal including audio, and sizes you can use _everywhere_ including on-device. Gemma 4 builds on advances from previous families and makes them click together. In our tests with pre-release checkpoints we have been impressed by their capabilities, to the extent that we struggled to find good fine-tuning examples because they are _so good_ out of the box.
|
||||
|
||||
We collaborated with Google and the community to make them available everywhere: transformers, llama.cpp, MLX, WebGPU, Rust; you name it. This blog post will show you how to build with [your favorite tools](https://huggingface.co/collections/google/gemma-4) so let us know what you think!
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [What is New with Gemma 4?](#what-is-new-with-gemma-4)
|
||||
- [Overview of Capabilities and Architecture](#overview-of-capabilities-and-architecture)
|
||||
- [Architecture at a Glance](#architecture-at-a-glance)
|
||||
- [Per-Layer Embeddings (PLE)](#per-layer-embeddings-ple)
|
||||
- [Shared KV Cache](#shared-kv-cache)
|
||||
- [Multimodal Capabilities](#multimodal-capabilities)
|
||||
- [Deploy Anywhere](#deploy-anywhere)
|
||||
- [transformers](#transformers)
|
||||
- [Llama.cpp](#llamacpp)
|
||||
- [Plug in to your local agent](#Plug-in-your-local-agent)
|
||||
- [transformers.js](#transformersjs)
|
||||
- [MLX](#mlx)
|
||||
- [Mistral.rs](#mistralrs)
|
||||
- [Fine-tuning & Demos](#fine-tuning--demos)
|
||||
- [Fine-tuning with TRL](#fine-tuning-with-trl)
|
||||
- [Fine-tuning with TRL on Vertex AI](#fine-tuning-with-trl-on-vertex-ai)
|
||||
- [Fine-tuning with Unsloth Studio](#fine-tuning-with-unsloth-studio)
|
||||
- [Try Gemma 4](#try-gemma-4)
|
||||
- [Benchmark Results](#benchmark-results)
|
||||
- [Acknowledgements](#acknowledgements)
|
||||
|
||||
# What is new with Gemma 4?
|
||||
|
||||
Similar to Gemma-3n, Gemma 4 supports image, text, and audio inputs, and generates text responses. The text decoder is based on the Gemma model with support for long context windows. The image encoder is similar to the one from Gemma 3 but with two crucial improvements: variable aspect ratios, and configurable number of image token inputs to find your sweet spot between speed, memory, and quality. All models support images (or video) and text inputs, while the small variants (E2B and E4B) support audio as well.
|
||||
|
||||
Gemma 4 comes in four sizes, all base and instruction fine-tuned:
|
||||
|
||||
| Model | Parameter Size | Context Window | Checkpoints |
|
||||
| :---- | :---- | :---- | :---- |
|
||||
| Gemma 4 E2B | 2.3B effective, 5.1B with embeddings | 128k | [base](https://huggingface.co/google/gemma-4-E2B), [IT](https://huggingface.co/google/gemma-4-E2B-it) |
|
||||
| Gemma 4 E4B | 4.5B effective, 8B with embeddings | 128k | [base](https://huggingface.co/google/gemma-4-E4B), [IT](https://huggingface.co/google/gemma-4-E4B-it) |
|
||||
| Gemma 4 31B | 31B dense model | 256K | [base](https://huggingface.co/google/gemma-4-31B), [IT](https://huggingface.co/google/gemma-4-31B-it) |
|
||||
| Gemma 4 26B A4B | mixture-of-experts with 4B activated/26B total parameters | 256K | [base](https://huggingface.co/google/gemma-4-26B-A4B), [IT](https://huggingface.co/google/gemma-4-26B-A4B-it) |
|
||||
|
||||
## Overview of Capabilities and Architecture
|
||||
|
||||
Gemma 4 leverages several architecture components used in previous Gemma versions and other open models, and leaves out complex or inconclusive features such as Altup. The combination is a mix designed to be highly compatible across libraries and devices, that can efficiently support long context and agentic use cases, whilst being ideal for quantization.
|
||||
|
||||
As shown in the benchmarks above, this feature mix (combined with the training data and recipe) enables the 31B dense model to achieve an estimated LMArena score (text only) of 1452, while the 26B MoE reaches 1441 with just 4B active parameters 🤯. As we'll see, multimodal operation is comparatively as good as text generation, at least in informal and subjective tests.
|
||||
|
||||
These are the main architecture characteristics in Gemma 4:
|
||||
|
||||
* Alternating **local sliding-window** and **global full-context** attention layers. Smaller dense models use sliding windows of 512 tokens while larger models use 1024 tokens.
|
||||
* **Dual RoPE** configurations: standard RoPE for sliding layers, pruned RoPE for global layers, to enable longer context.
|
||||
* **Per-Layer Embeddings (PLE)**: a second embedding table that feeds a small residual signal into every decoder layer.
|
||||
* **Shared KV Cache**: the last N layers of the model reuse key-value states from earlier layers, eliminating redundant KV projections.
|
||||
* **Vision encoder**: uses learned 2D positions and multidimensional RoPE. Preserves the original aspect ratios and can encode images to a few different token budgets (70, 140, 280, 560, 1120).
|
||||
* **Audio encoder**: USM-style conformer with the same base architecture as the one in Gemma-3n.
|
||||
|
||||
#### Per-Layer Embeddings (PLE)
|
||||
|
||||
One of the most distinctive features in smaller Gemma 4 models is Per-Layer Embeddings (PLE), which was introduced previously in Gemma-3n. In a standard transformer, each token gets a single embedding vector at input, and the same initial representation is what the residual stream builds on across all layers, forcing the embedding to frontload everything the model might need. PLE adds a parallel, lower-dimensional conditioning pathway alongside the main residual stream. For each token, it produces a small dedicated vector for every layer by combining two signals: a token-identity component (from an embedding lookup) and a context-aware component (from a learned projection of the main embeddings). Each decoder layer then uses its corresponding vector to modulate the hidden states via a lightweight residual block after attention and feed-forward. This gives each layer its own channel to receive token-specific information only when it becomes relevant, rather than requiring everything to be packed into a single upfront embedding. Because the PLE dimension is much smaller than the main hidden size, this adds meaningful per-layer specialization at modest parameter cost. For multimodal inputs (images, audio, video), PLE is computed before soft tokens are merged into the embedding sequence — since PLE relies on token IDs that are lost once multimodal features replace the placeholders. Multimodal positions use the pad token ID, effectively receiving neutral per-layer signals.
|
||||
|
||||
#### Shared KV Cache
|
||||
|
||||
The **shared KV cache** is an efficiency optimization that reduces both compute and memory during inference. The last `num_kv_shared_layers` layers of the model don't compute their own key and value projections. Instead, they **reuse** the K and V tensors from the last non-shared layer of the same attention type (sliding or full).
|
||||
|
||||
In practice, this has a minimal impact on quality while being much more efficient (in terms of both memory and compute) for long context generation and on-device use.
|
||||
|
||||
## Multimodal Capabilities
|
||||
|
||||
We saw in our tests that Gemma 4 supports comprehensive multimodal capabilities out of the box. We don't know what was the training mix, but we had success using it for tasks such as OCR, speech-to-text, object detection, or pointing. It also supports text-only and multimodal function calling, reasoning, code completion and correction.
|
||||
|
||||
Here, we show a few inference examples across different model sizes. You can run them conveniently with [this notebook](https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb). We encourage you to try the demos and share them below this blog!
|
||||
|
||||
### Object Detection and Pointing
|
||||
|
||||
### GUI detection
|
||||
|
||||
We test Gemma 4 on GUI element detection and pointing across different sizes, with the following image and text prompt: "What's the bounding box for the "view recipe" element in the image?"
|
||||
|
||||

|
||||
|
||||
With this prompt, the model natively responds in JSON format with the detected bounding boxes, with no need for specific instructions or grammar-constrained generation. We found that the coordinates are expressed on a 1000x1000 grid, so they have to be rescaled to the dimensions of the input image.
|
||||
|
||||
We visualize the outputs below for your convenience. We parse the bounding boxes from the returned JSON:

```json
[
  {"box_2d": [171, 75, 245, 308], "label": "view recipe element"}
]
```
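A minimal sketch of that parsing step follows. It assumes the four `box_2d` values are ordered `[y_min, x_min, y_max, x_max]` on the 1000x1000 grid; the order is an assumption, not something stated here, so check it against a known element before relying on it. The function name is hypothetical.

```python
import json


def boxes_to_pixels(response_json, image_width, image_height, grid=1000):
    """Convert grid-relative boxes from the model's JSON into pixel boxes."""
    results = []
    for detection in json.loads(response_json):
        # Assumed order: [y_min, x_min, y_max, x_max] on a `grid` x `grid` canvas.
        y_min, x_min, y_max, x_max = detection["box_2d"]
        results.append({
            "label": detection["label"],
            # Pixel box as (left, top, right, bottom).
            "box": (
                round(x_min * image_width / grid),
                round(y_min * image_height / grid),
                round(x_max * image_width / grid),
                round(y_max * image_height / grid),
            ),
        })
    return results
```

The `(left, top, right, bottom)` tuple matches what most drawing APIs expect for a rectangle, which is why the x and y values are swapped on the way out.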
|
||||
|
||||
| E2B | E4B |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
| 26B/A4B | 31B |
|
||||
|  |  |
|
||||
|
||||
### Object Detection
|
||||
|
||||
We also test the models on everyday objects. Here we ask them to detect the bike and compare the outputs across model sizes. As in the previous case, we parse the bounding box from the JSON and translate it to image-space coordinates.
|
||||
|
||||
| E2B | E4B | 26B/A4B | 31B |
|
||||
| :---- | :---- | :---- | :---- |
|
||||
|  |  |  |  |
|
||||
|
||||
### Multimodal Thinking and Function Calling
|
||||
|
||||
We asked Gemma 4 to write HTML code to reconstruct a page we made with Gemini 3. The code for this is below: we enable thinking and let each model generate up to 4000 new tokens to leave ample headroom.
|
||||
|
||||
| Gemini Generated Website (Reference) | Gemini Reproduced Image |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
|
||||
},
|
||||
{"type": "text", "text": "Write HTML code for this page."},
|
||||
],
|
||||
}
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=4000)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| E2B | E4B |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
|
||||
| 31B | 26B/A4B |
|
||||
| :---- | :---- |
|
||||
|  |  |
|
||||
|
||||
### Video Understanding
|
||||
|
||||
Smaller Gemma 4 models can take in videos with audio, while larger ones can take in videos without audio. Although the models are not explicitly post-trained on videos, they can understand them in both cases. The models are particularly strong on audio.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
|
||||
{"type": "text", "text": "What is happening in the video? What is the song about?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
load_audio_from_video=True, # disable this for larger models
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=200)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B (hallucinated for the audio) | Based on the video provided, here is an analysis: **What is happening in this video?** The video captures a live musical performance, likely at a large outdoor festival or concert. * **Setting:** The scene is a large stage set up outdoors, indicated by the extensive lighting rig, large speakers, and the audience visible in the foreground. * **Performers:** There are several musicians on stage. The focal point seems to be a lead singer/guitarist in a blue shirt and white pants who is actively performing, playing an acoustic or electric guitar. There are also other musicians, including drummers and bassists, visible in the background. * **Atmosphere:** The atmosphere is energetic and dramatic, characterized by bright stage lights, intense blue and white spotlights, and a significant amount of stage fog/smoke, which adds to the concert mood. * **Audience:** The foreground is filled with the backs of audience members, showing they are close to the stage and engaged in watching the performance. **What is the song about?** The lyrics provided below the video are: > "I don't care 'bout street > From what I cherish > Could it's time to protest |
|
||||
| E4B ✅ | The video captures a live performance of a band on a stage with bright lights and smoke effects, surrounded by an audience. The band performs and sings to the crowd. The song appears to be about the struggles and disillusionment of modern life, specifically the feeling of being stuck and the difficulties in finding one's place in a chaotic and overwhelming world. Lyrics mention themes of the "concrete street," "broken promises," "storm of the day," and the difficulty of finding oneself amidst societal pressures. |
|
||||
| 26B/A4B (no audio) ✅ | This video shows a live music performance at an outdoor concert or festival at night. A solo male artist, dressed in a blue shirt and white pants, is the central focus. He is playing an acoustic guitar and singing into a microphone on a large, lit stage. The performance is accompanied by a band, including a drummer and other musicians visible in the background and on large video screens. The stage is filled with atmospheric blue lighting, smoke, and bright spotlights. The video captures different perspectives, including shots from the crowd looking up at the stage and close-up views of the performers projected onto large screens. The crowd is visible in the foreground, watching the show. One of the screens also displays the "Estrella Damm" brand name. |
|
||||
| 31B (no audio) ✅ | A musician is performing on stage during a concert, playing an acoustic guitar and singing into a microphone. He is wearing a blue shirt and white pants. The stage is filled with atmospheric smoke and illuminated by bright blue and white lights. Other band members, including a drummer and a keyboardist, are also visible on stage. The video shows the musician from the perspective of the audience, with some shots focusing on him and others showing the large screen on the side of the stage, which displays close-ups of the performer and the crowd. |
|
||||
|
||||
### Captioning
|
||||
|
||||
We have tested all models on captioning. All checkpoints perform very well and accurately capture nuance in complex scenarios.
|
||||
Here's the image we use, with the prompt "Write single detailed caption for this image.".
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
|
||||
{"type": "text", "text": "Write single detailed caption for this image."},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=512)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B ✅ | A medium shot captures a weathered seagull perched atop a stone pedestal in what appears to be a bustling European square, with a grand, classical-style building featuring ornate columns and architectural details dominating the right side of the frame. In the background, a low, light-colored building stretches across the middle ground, flanked by some green foliage, and hints of other structures, including a terracotta-roofed building, are visible. In the foreground, part of a modern bus stop or shelter is visible, with signage indicating routes, and the overall scene is bathed in bright daylight under a pale sky. |
|
||||
| E4B ✅ | A solitary seagull perches atop a grey stone pedestal in a bustling European piazza, with the historical grandeur of a classical building dominating the right side of the frame, contrasting with the more modern, low-rise architecture in the background under a bright, hazy sky. |
|
||||
| 26B/A4B ✅ | A full shot shows a white and brown seagull perched on a gray pillar in Venice, Italy. The seagull faces left, its beak pointed towards the edge of the frame. Behind it is a light gray, two-story building with several dark rectangular windows. To the right of the seagull is a large, white building with columns and intricate carvings. The sky is a clear, light blue. The lighting is bright and even, suggesting it is daytime. |
|
||||
| 31B ✅ | A high-angle, medium shot shows a white and gray seagull perched on a gray concrete pylon in Venice, Italy. The seagull is facing left, its body angled slightly towards the camera. Its white head, neck, and chest contrast with its gray wings and back. The seagull's orange beak is slightly open. The pylon is a light gray concrete cylinder with a weathered texture. In the background, there are several buildings. To the left is a long, low, white building with a series of small, rectangular windows. In the center is a cluster of green trees. To the right is a red building with a tiled roof and several small, rectangular windows. Further right is a large, ornate, white stone building with several columns and sculptures. The sky is a pale blue, and the light is bright and clear. In the lower-left corner, there is a yellow bus with the word "Ferrovia" printed on it in black letters. There are several people walking in the background, but they are small and out of focus. |
|
||||
|
||||
### Audio Question Answering
|
||||
|
||||
These models are trained to answer questions about speech in audio. Music and non-speech sounds were not part of the training data.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
|
||||
{"type": "text", "text": "Can you describe this audio in detail?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B | This audio is a personal reflection. The speaker is talking about their final farewell address to the nation, which they delivered in Chicago. They express gratitude for the conversations they've had with the American people, noting that despite not having met them face-to-face or even greeted them, these interactions in various settings like living rooms, schools, farms, factory floors, diners, and military outposts have been what has kept them going. |
|
||||
| E4B | The audio is a speech excerpt where a speaker is delivering a farewell address to the nation from Chicago. The speaker reflects on their time in office, expressing gratitude for the conversations they had with the American people across various settings like living rooms, schools, farms, factories, diners, and military outposts. The tone is reflective and appreciative, highlighting the importance of these interactions in their political journey. |
|
||||
|
||||
Here is an example if you want to do transcription:
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
|
||||
{"type": "text", "text": "Transcribe the audio?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B | This week I traveled to Chicago to deliver my final farewell address to the nation following in the tradition of presidents before me It was an opportunity to say thank you whether we've seen eye to eye or rarely agreed at all my conversations with you the American people in living rooms and schools at farms and on factory floors at diners and on distant military outposts all these conversations are what have kept me honest |
|
||||
| E4B | This week I traveled to Chicago to deliver my final farewell address to the nation following in the tradition of presidents before me. It was an opportunity to say thank you. Whether we've seen eye to eye or rarely agreed at all, my conversations with you, the American people, in living rooms and schools, at farms and on factory floors, at diners and on distant military outposts, all these conversations are what have kept me honest. |
|
||||
|
||||
### Multimodal Function Calling
|
||||
|
||||
We test the model by asking it to get the weather for the place shown in the image.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```py
|
||||
import re
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Gets the current weather for a specific location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
tools = [WEATHER_TOOL]
|
||||
messages = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "What is the city in this image? Check the weather there right now."},
|
||||
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
|
||||
]},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tools=tools,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=1000)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model | Output |
|
||||
| :---- | :---- |
|
||||
| E2B | **Analyze the Request:** The user has provided an image and asked two questions:\n * "What is the city in this image?" * "Check the weather there right now."\n\n2. **Analyze the Image:** The image displays a large, ornate, tall structure, characteristic of temple architecture, likely in Southeast Asia (given the style). It appears to be a famous landmark. 3. **Identify the Landmark/Location (Internal Knowledge/Visual Cues):** The structure strongly resembles the chedis/stupas found in places like Bangkok, Thailand (e.g., Wat Arun or similar structures).\n\n4. **Determine the City:** Based on the visual evidence, the location is highly likely to be Bangkok, Thailand.\n\n5. **Check Available Tools:** The available tool is `get_weather(city: STRING)`. 6. **Formulate the Plan:**\n * Step 1: Identify the city from the image (Bangkok).\n * Step 2: Use the `get_weather` tool with the identified city.\n\n7. **Construct the Tool Call:** `get_weather(city="Bangkok")`\n\n8. **Generate Response:** The model should first state the identified city and then call the tool to get the weather. (Self-correction: The model *must* call the tool in the correct format if it needs external information.)call:get_weather{city:Bangkok} |
|
||||
| E4B | '1. **Analyze the Request:** The user has provided an image and asked two things:\n a. Identify the city in the image.\n b. Check the weather in that city right now.\n\n2. **Analyze the Image:** The image shows a very distinctive, ornate tower/spire structure, characteristic of Thai architecture (specifically, a chedi or prang).\n\n3. **Determine City Identification (Internal Knowledge/Visual Recognition):** Based on the architectural style, the structure strongly resembles prominent temples in Thailand, such as Wat Arun in Bangkok.\n\n4. **Evaluate Tool Availability:** I have one available tool: `get_weather(city: STRING)`. This tool requires a city name as input.\n\n5. **Formulate a Plan:**\n a. I need the city name to use the `get_weather` tool.\n b. I must first identify the city from the image.\n c. Since my internal knowledge suggests Bangkok, I will proceed with that assumption, but I must acknowledge that I cannot *definitively* prove the location from the image alone.\n\n6. **Execute Step 1 (City Identification):** Assume the city is Bangkok based on the visual evidence of the architecture.\n\n7. **Execute Step 2 (Tool Call Generation):** Use the `get_weather` tool with "Bangkok" as the `city` argument.\n\n8. **Construct the Response:** Generate the appropriate tool call in JSON format.', 'answer': '<\|tool_call>call:get_weather{city:<\|"\|>Bangkok<\|"\|>}<tool_call\|><eos>' |
|
||||
| 31B | The user wants to know the city in the image and the current weather there. 1. **Analyze the image:** The image shows ornate, spire-like structures (prangs) characteristic of Thai temple architecture. Specifically, these look like the prangs at Wat Phra Kaew (Temple of the Emerald Buddha) or Wat Pho in Bangkok, Thailand. The style is very distinctive of Bangkok's royal temple complexes. 2. **Identify the city:** Based on the architectural style, the city is Bangkok. 3. **Find the weather:** I need to call the `get_weather` tool with `city="Bangkok"`.call:get_weather{city:Bangkok} |
|
||||
| 26B/A4B | The image shows several tall, ornate towers (prangs) typical of Thai Buddhist temple architecture. Specifically, these look like the structures within the Grand Palace or Wat Phra Kaew in Bangkok, Thailand. The style is unmistakably Thai. I'll confirm this by looking at the specific towers. These are very iconic.\n2. **Confirm the location:** The image shows elements of the Grand Palace complex in Bangkok.\n3. **Get weather information:** Once the city is identified as Bangkok, I will use the `get_weather` tool with `city='Bangkok'`.call:get_weather{city:Bangkok} |
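If you decode the output yourself, the tool call still has to be extracted from the text. The helper below is a rough sketch whose format is inferred only from the sample outputs in the table above (`call:NAME{key:value}`, with string values optionally wrapped in `<|"|>` markers). It is not the official parser, and the function name is hypothetical.

```python
import re

# Matches `call:NAME{...}` and captures the name and the raw argument body.
CALL_PATTERN = re.compile(r"call:(\w+)\{(.*?)\}", re.DOTALL)
QUOTE_MARKER = '<|"|>'  # string delimiter seen in the raw E4B output


def parse_tool_calls(text):
    """Extract tool calls from generated text as a list of dicts."""
    calls = []
    for name, body in CALL_PATTERN.findall(text):
        arguments = {}
        for pair in body.split(","):
            if ":" not in pair:
                continue
            key, value = pair.split(":", 1)
            arguments[key.strip()] = value.replace(QUOTE_MARKER, "").strip()
        calls.append({"name": name, "arguments": arguments})
    return calls
```

This simple splitting breaks on values that contain commas or braces, so treat it as a starting point for single-argument tools like `get_weather` and nothing more.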
|
||||
|
||||
# Deploy Anywhere
|
||||
|
||||
Gemma 4 comes with day-0 support for many open-source inference engines, and is ideal for tool calling and agents! We also release ONNX checkpoints that can run on many hardware backends, enabling use cases on edge devices or in the browser!
|
||||
|
||||
## transformers
|
||||
|
||||
Gemma 4 comes with first-class transformers support from the get-go 🤗. This integration allows using the model with other libraries like bitsandbytes, PEFT and TRL. Make sure to install the latest version of transformers.
|
||||
|
||||
```bash
|
||||
pip install -U transformers
|
||||
```
|
||||
|
||||
The easiest way to infer with the small Gemma 4 models is through the `any-to-any` pipeline. You can initialize it as follows.
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
pipe = pipeline("any-to-any", model="google/gemma-4-E2B-it")
|
||||
```
|
||||
|
||||
You can then pass in images and text as follows.
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg",
|
||||
},
|
||||
{"type": "text", "text": "Do you have travel advice going to here?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
output = pipe(messages, max_new_tokens=100, return_full_text=False)
|
||||
output[0]["generated_text"]
|
||||
# Based on the image, which appears to show a magnificent, ornate **Buddhist temple or pagoda**, likely in Southeast Asia (such as Thailand, Myanmar, or Cambodia), here is some general travel advice..
|
||||
```
|
||||
|
||||
When inferring with videos, you can include the audio track using the `load_audio_from_video` argument.
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
pipe(messages, load_audio_from_video=True)
|
||||
```
|
||||
|
||||
Going a level lower, you can load Gemma 4 using the `AutoModelForMultimodalLM` class, which is especially useful for fine-tuning. The built-in chat template takes care of formatting the inputs correctly. Please make sure you use it, as building the prompt manually invites subtle mistakes.
|
||||
|
||||
<details>
|
||||
<summary>Inference code</summary>
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForMultimodalLM, AutoProcessor
|
||||
model = AutoModelForMultimodalLM.from_pretrained("google/gemma-4-E2B-it", device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Llama.cpp
|
||||
|
||||
Gemma 4 models come with image+text support in llama.cpp from the get-go! This unlocks using Gemma 4 with all of your favorite local apps, such as the llama.cpp server, LM Studio, and Jan, as well as coding agents like Pi, across many backends such as Metal and CUDA.
|
||||
|
||||
You can install llama.cpp as follows.
|
||||
|
||||
```bash
|
||||
brew install llama.cpp # macOS
|
||||
winget install llama.cpp # Windows
|
||||
```
|
||||
|
||||
You can then start a server compatible with the OpenAI API. To choose a precision, append a quantization tag to the model name, such as the `:Q4_K_M` suffix used in the next section.
|
||||
|
||||
```bash
|
||||
llama-server -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
```
|
||||
|
||||
Check out this link [for more](https://huggingface.co/ggml-org/gemma-4-E2B-it-GGUF?local-app=llama.cpp) options on combining llama.cpp with different coding agents and local apps. Find all the GGUF checkpoints [in this collection](https://huggingface.co/collections/ggml-org/gemma-4).
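Once the server is running, any OpenAI-style client can talk to it. The sketch below uses only the Python standard library and assumes the server listens on llama-server's default address, `http://localhost:8080`. The `model` value is a placeholder, and both function names are hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8080/v1"  # assumed: llama-server default host and port


def build_chat_request(prompt, model="gemma-4", base_url=BASE_URL):
    """Build an OpenAI-style chat completion request for the local server."""
    payload = {
        "model": model,  # placeholder name for the locally served model
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 256,
    }
    return urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt):
    """Send the prompt and return the text of the first choice."""
    with urllib.request.urlopen(build_chat_request(prompt)) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]
```

With the server from the command above running, `print(chat("Say hello in one sentence."))` should print the model's reply.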
|
||||
|
||||
## Plug in your local agent
|
||||
|
||||
We worked on making sure the new models work locally with agents like **openclaw, hermes, pi, and open code**. All thanks to llama.cpp! Run the following to try Gemma 4 right away.
|
||||
|
||||
First, start your local server:
|
||||
|
||||
```shell
|
||||
llama-server -hf ggml-org/gemma-4-26b-a4b-it-GGUF:Q4_K_M
|
||||
```
|
||||
|
||||
For **hermes:**
|
||||
|
||||
```shell
|
||||
hermes model
|
||||
```
|
||||
|
||||
For **openclaw:**
|
||||
|
||||
```shell
|
||||
openclaw onboard
|
||||
```
|
||||
|
||||
For **pi** define a `~/.pi/agent/models.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"llama-cpp": {
|
||||
"baseUrl": "http://localhost:8080/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "none",
|
||||
"models": [
|
||||
{
|
||||
"id": "ggml-org-gemma-4-26b-4b-gguf"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
For **open code** define a `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"llama.cpp": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "llama-server (local)",
|
||||
"options": {
|
||||
"baseURL": "http://127.0.0.1:8080/v1"
|
||||
},
|
||||
"models": {
|
||||
"gemma-4-26b-4b-it": {
|
||||
"name": "Gemma 4 (local)",
|
||||
"limit": {
|
||||
"context": 128000,
|
||||
"output": 8192
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## transformers.js
|
||||
|
||||
transformers.js enables running Gemma 4 right inside the browser. You can check out the model card to see text-only, image & text, and audio & text inference in detail [here](https://huggingface.co/onnx-community/gemma-4-E2B-it-ONNX#transformersjs-javascript). We also shipped a demo for you to test the model [here](https://huggingface.co/spaces/webml-community/Gemma-4-WebGPU).
|
||||
|
||||
## MLX
|
||||
|
||||
Full multimodal support of Gemma 4 is available using the open-source [`mlx-vlm` library](https://github.com/Blaizzy/mlx-vlm). Here's how to ask the model to describe an image:
|
||||
|
||||
```shell
|
||||
pip install -U mlx-vlm
|
||||
```
|
||||
|
||||
```shell
|
||||
mlx_vlm.generate \
|
||||
--model google/gemma-4-E4B-it \
|
||||
--image https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg \
|
||||
--prompt "Describe this image in detail"
|
||||
```
|
||||
|
||||
mlx-vlm supports TurboQuant, which delivers the same accuracy as the uncompressed baseline while using ~4x less active memory and running a lot faster end-to-end. This makes long-context inference practical on Apple Silicon without sacrificing quality. Use it like this:
|
||||
|
||||
```shell
|
||||
mlx_vlm.generate \
|
||||
--model "mlx-community/gemma-4-26b-a4b-it-4bit" \
|
||||
--prompt "Your prompt here" \
|
||||
--kv-bits 3.5 \
|
||||
--kv-quant-scheme turboquant
|
||||
```
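To get a feel for what `--kv-bits 3.5` means for memory, here is a back-of-the-envelope estimate of KV cache size. It assumes a 16-bit baseline cache and ignores any metadata the quantization scheme stores. The model dimensions are hypothetical placeholders, not Gemma 4's actual configuration.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bits_per_value):
    """Approximate KV cache size: one K and one V vector per layer, head and token."""
    values = 2 * num_layers * num_kv_heads * head_dim * seq_len
    return values * bits_per_value / 8


# Hypothetical dimensions for illustration only.
dims = dict(num_layers=48, num_kv_heads=8, head_dim=128, seq_len=128_000)

baseline = kv_cache_bytes(**dims, bits_per_value=16)    # assumed 16-bit baseline
quantized = kv_cache_bytes(**dims, bits_per_value=3.5)  # matches --kv-bits 3.5

print(f"baseline:  {baseline / 2**30:.1f} GiB")
print(f"quantized: {quantized / 2**30:.1f} GiB")
print(f"reduction: {baseline / quantized:.2f}x")
```

Under these assumptions the ratio depends only on the bit widths, so the reduction is the same at every context length. The absolute savings grow with sequence length, which is why the effect matters most for long-context inference.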
|
||||
|
||||
For audio examples and more details, please check [the MLX collection](https://hf.co/mlx-community/gemma-4).
|
||||
|
||||
## Mistral.rs
|
||||
|
||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a Rust-native inference engine with day-0 Gemma 4 support across all modalities (text, image, video, audio) and built-in tool-calling and agentic functionality. Install mistral.rs:
|
||||
|
||||
```bash
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/EricLBuehler/mistral.rs/master/install.sh | sh # Linux/macOS
|
||||
|
||||
irm https://raw.githubusercontent.com/EricLBuehler/mistral.rs/master/install.ps1 | iex # Windows
|
||||
```
|
||||
|
||||
You can then start an OpenAI-compatible HTTP server:
|
||||
|
||||
```bash
|
||||
mistralrs serve mistralrs-community/gemma-4-E4B-it-UQFF --from-uqff 8
|
||||
```
|
||||
|
||||
Or, use interactive mode:
|
||||
|
||||
```bash
|
||||
mistralrs run -m google/gemma-4-E4B-it --isq 8 --image image.png -i "Describe this image in detail."
|
||||
|
||||
mistralrs run -m google/gemma-4-E4B-it --isq 8 --audio audio.mp3 -i "Transcribe this fully."
|
||||
```
|
||||
|
||||
Find all models [here](https://huggingface.co/mistralrs-community/models). Please follow [the instructions](https://huggingface.co/mistralrs-community/gemma-4-E2B-it-UQFF#install) in the model cards for installation and inference guidelines.
|
||||
|
||||
## Fine-tuning for all
|
||||
|
||||
Gemma 4 models are ideal for fine-tuning in your favorite tools and platforms and at any budget.
|
||||
|
||||
## Fine-tuning with TRL
|
||||
|
||||
Gemma 4 is fully supported for fine-tuning with TRL. To celebrate, TRL has been upgraded with support for multimodal tool responses when interacting with environments, meaning models can now receive images back from tools during training, not just text.
|
||||
|
||||
To showcase this, we've built an example training script where Gemma 4 learns to drive in the CARLA simulator. The model sees the road through a camera, decides what to do and learns from the outcome. After training, it consistently changes lanes to avoid pedestrians. The same approach works for any task where a model needs to see and act: robotics, web browsing, or other interactive environments.
|
||||
|
||||
Get started:
|
||||
|
||||
```shell
|
||||
# pip install git+https://github.com/huggingface/trl.git
|
||||
|
||||
python examples/scripts/openenv/carla_vlm_gemma.py \
|
||||
--env-urls https://sergiopaniego-carla-env.hf.space \
|
||||
https://sergiopaniego-carla-env-2.hf.space \
|
||||
--model google/gemma-4-E2B-it
|
||||
```
|
||||
|
||||
Find the example [here](https://github.com/huggingface/huggingface-gemma-recipes/blob/main/scripts/carla_vlm_gemma.py).
|
||||
|
||||
### Fine-tuning with TRL on Vertex AI
|
||||
|
||||
Additionally, we have prepared an example on how to fine-tune Gemma 4 with TRL on Vertex AI using SFT, to showcase how to extend the function calling capabilities, whilst freezing both the vision and audio towers. The examples include how to build a custom Docker container with latest Transformers, TRL, etc. with CUDA support on Google Cloud, and how to run it via Vertex AI Serverless Training Jobs.
|
||||
|
||||
```python
# pip install google-cloud-aiplatform --upgrade --quiet
from google.cloud import aiplatform

# Point the SDK at your project, region, and staging bucket
aiplatform.init(
    project="<PROJECT_ID>",
    location="<LOCATION>",
    staging_bucket="<BUCKET_URI>",
)

# Describe the training job that runs inside the custom container
job = aiplatform.CustomContainerTrainingJob(
    display_name="gemma-4-fine-tuning",
    container_uri="<CONTAINER_URI>",
    command=["python", "/gcs/gemma-4-fine-tuning/train.py"],
)

# Submit the job on a single H100 machine
job = job.submit(
    replica_count=1,
    machine_type="a3-highgpu-1g",
    accelerator_type="NVIDIA_H100_80GB",
    accelerator_count=1,
    base_output_dir="<BUCKET_URI>/output-dir",
    environment_variables={
        "MODEL_ID": "google/gemma-4-E2B-it",
        "HF_TOKEN": "<HF_TOKEN>",  # replace with your Hugging Face access token
    },
    boot_disk_size_gb=500,
)
```
|
||||
|
||||
You can find the complete example in the "Hugging Face on Google Cloud" docs at https://hf.co/docs/google-cloud/examples/vertex-ai-notebooks-fine-tune-gemma-4.
|
||||
|
||||
## Fine-tuning with Unsloth Studio
|
||||
|
||||
If you want to fine-tune and run a Gemma 4 model in a UI, try out [Unsloth Studio](https://unsloth.ai/docs/new/studio). It runs locally or on Google Colab. First, install and start the app:
|
||||
|
||||
```shell
# install unsloth studio on macOS, Linux, WSL
curl -fsSL https://unsloth.ai/install.sh | sh

# install unsloth studio on Windows (PowerShell)
irm https://unsloth.ai/install.ps1 | iex

# launch unsloth studio
unsloth studio -H 0.0.0.0 -p 8888
# Search for a Gemma 4 model like google/gemma-4-E2B-it
```
|
||||
|
||||
Then select any of the Gemma 4 models from the hub.
|
||||
|
||||

|
||||
|
||||
## Try Gemma 4
|
||||
|
||||
We have shipped demos for you to try different Gemma 4 models. We include demos based on transformers implementation for [E4B](https://huggingface.co/spaces/huggingface-projects/gemma-4-e4b-it), [26B/A4B](https://huggingface.co/spaces/huggingface-projects/gemma-4-26b-a4b-it), and dense [31B](https://huggingface.co/spaces/huggingface-projects/gemma-4-31b-it) models, as well as a [WebGPU](https://huggingface.co/spaces/webml-community/Gemma-4-WebGPU) demo with transformers.js 🚀
|
||||
|
||||
|
||||
<iframe width="560" height="315" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/webgpu_demo.mp4" title="WebGPU Demo" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen> </iframe>
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
Gemma 4 models demonstrate exceptional performance across diverse benchmarks, from reasoning and coding to vision and long-context tasks. The graph below shows model performance vs size, with Gemma 4 models forming an impressive Pareto frontier:
|
||||
|
||||
<div style="display: flex; gap: 20px; justify-content: center; align-items: flex-start; flex-wrap: wrap;">
|
||||
<figure style="flex: 1; min-width: 300px; text-align: center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/g4_graph.png" alt="Gemma 4 Performance vs Size" style="width: 100%; height: 400px; object-fit: contain;">
|
||||
</figure>
|
||||
<figure style="flex: 1; min-width: 300px; text-align: center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/g4_graph_2.png" alt="Gemma 4 Arena Elo Score Comparison" style="width: 100%; height: 400px; object-fit: contain;">
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
<p style="text-align: center; font-size: 0.9em; color: #666;">Source: Google (<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/">blog.google</a>)</p>
|
||||
|
||||
Here are detailed benchmark results for the instruction-tuned models:
|
||||
|
||||
| Benchmark | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
|-----------|-------------|-----------------|-------------|-------------|------------------------|
| **Reasoning & Knowledge** | | | | | |
| MMLU Pro | [85.2%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-31B-it) | [82.6%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-26B-A4B-it) | [69.4%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-E4B-it) | [60.0%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-E2B-it) | 67.6% |
| AIME 2026 no tools | [89.2%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-31B-it) | [88.3%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-26B-A4B-it) | [42.5%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-E4B-it) | [37.5%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-E2B-it) | 20.8% |
| GPQA Diamond | [84.3%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-31B-it) | [82.3%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-26B-A4B-it) | [58.6%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-E4B-it) | [43.4%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-E2B-it) | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Coding** | | | | | |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| HLE no tools | [19.5%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-31B-it) | [8.7%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-26B-A4B-it) | - | - | - |
| HLE with search | [26.5%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-31B-it) | [17.2%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-26B-A4B-it) | - | - | - |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
Landing Gemma 4 in the open-source ecosystem took a lot of effort from many people, not only the authors of this blog post. In no particular order, we thank many people from the open-source team: the Gemma 4 transformers integration is owed to Cyril, Raushan, Eustache, Arthur, and Lysandre. We thank Joshua for the transformers.js integration and demo, Eric for the mistral.rs integration, Son for llama.cpp, Prince for the MLX integration, Quentin, Albert, and Kashif for TRL, Adarsh for the SGLang transformers backend, and Toshihiro for building the demos.
|
||||
This work wouldn't have been possible without Google's extensive contribution with the model artefact, but also the significant effort contributing the model to transformers in an effort to standardize it. The open-source ecosystem is now more complete, with a very capable, freely-licensed, open-source model.
|
||||
|
||||
@@ -0,0 +1,514 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
|
||||
pipeline_tag: image-text-to-text
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<img src="https://ai.google.dev/gemma/images/gemma4_banner.png">
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
|
||||
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
|
||||
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
|
||||
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
|
||||
<br>
|
||||
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
|
||||
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes text and images with variable aspect ratio and resolution support (all models), video, and audio (natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
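The interleaving described above can be illustrated with a minimal sketch. The ratio of local to global layers is not stated here, so `global_every` is a hypothetical parameter, and whether the window counts the current token is also an assumption; only the "final layer is always global" rule and the sliding-window idea come from the text.

```python
def layer_schedule(num_layers: int, global_every: int) -> list[str]:
    """Label each layer 'local' or 'global'.

    `global_every` is a hypothetical knob: every N-th layer is global.
    The last layer is forced to be global, as the text describes.
    """
    kinds = [
        "global" if (i + 1) % global_every == 0 else "local"
        for i in range(num_layers)
    ]
    kinds[-1] = "global"
    return kinds


def can_attend(query_pos: int, key_pos: int, kind: str, window: int) -> bool:
    """Causal attention rule for a single (query, key) pair."""
    if key_pos > query_pos:
        return False  # causal: never look at future tokens
    if kind == "global":
        return True  # global layers see the whole prefix
    # Local layers only see the most recent `window` positions (assumed inclusive).
    return query_pos - key_pos < window


# Example: 35 layers with an assumed pattern of one global layer in six.
schedule = layer_schedule(num_layers=35, global_every=6)
print(schedule.count("global"), "global layers out of", len(schedule))
```

With a 512-token window, a local layer at position 1000 cannot see position 0, while a global layer can; this is what keeps the per-layer memory of local layers bounded.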
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
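The gap between the two figures quoted in the table can be made concrete with a little arithmetic. The sketch below uses only the parameter counts listed above; attributing the entire difference to embedding tables is a simplifying assumption, and the helper name is illustrative.

```python
# Parameter counts in billions, copied from the Dense Models table.
MODELS = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_share(effective: float, with_embeddings: float) -> float:
    """Fraction of the total count that lies outside the effective parameters.

    Assumption: the whole difference is treated as embedding-table weights.
    """
    return (with_embeddings - effective) / with_embeddings


for name, counts in MODELS.items():
    share = embedding_share(counts["effective"], counts["with_embeddings"])
    print(f"{name}: {share:.0%} of parameters sit outside the effective count")
```

Under that assumption, roughly half of each small model's stored parameters are lookup tables rather than weights that take part in every matrix multiplication.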
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only a roughly 4B subset (3.8B) of its parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
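As a rough illustration of the active-versus-total distinction, the figures from the tables above can be compared directly. Treating compute per token as proportional to the number of active parameters is a simplifying assumption; real throughput also depends on memory bandwidth, routing overhead, and the serving stack.

```python
# Parameter counts in billions, taken from the tables above.
MOE_TOTAL_B = 25.2
MOE_ACTIVE_B = 3.8
DENSE_31B_B = 30.7


def active_fraction(active: float, total: float) -> float:
    """Share of the model's parameters used for a single forward pass."""
    return active / total


def relative_compute(active: float, dense: float) -> float:
    """How many times more parameters a dense model touches per token.

    Assumption: compute scales linearly with parameters used.
    """
    return dense / active


print(f"Active share of the MoE model: {active_fraction(MOE_ACTIVE_B, MOE_TOTAL_B):.1%}")
print(f"Dense 31B uses about {relative_compute(MOE_ACTIVE_B, DENSE_31B_B):.1f}x more parameters per token")
```

Note that all 25.2B parameters still have to be stored, so the memory footprint follows the total count, not the active one.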
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
|
||||
|
||||
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-text translation across multiple languages.
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
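To show what these three parameters do to a next-token distribution, here is a minimal pure-Python sketch of top-k followed by top-p (nucleus) filtering on a toy distribution. It is an illustration under stated assumptions, not how any particular library implements sampling; the toy probabilities and the helper name are made up. In Transformers, sampling parameters like these generally only take effect when `do_sample=True` is passed to `generate`.

```python
import random

# Recommended values from the list above.
SAMPLING = {"temperature": 1.0, "top_p": 0.95, "top_k": 64}


def filter_top_k_top_p(probs: dict[str, float], top_k: int, top_p: float) -> dict[str, float]:
    """Keep the top_k most likely tokens, then the smallest prefix whose
    cumulative probability reaches top_p, and renormalize what is left."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, mass = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        mass += p
        if mass >= top_p:
            break
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}


# Toy next-token distribution (hypothetical values). A temperature of 1.0
# leaves the model's probabilities unchanged, so no rescaling is applied here.
toy = {"a": 0.5, "b": 0.3, "c": 0.17, "d": 0.02, "e": 0.01}
filtered = filter_top_k_top_p(toy, SAMPLING["top_k"], SAMPLING["top_p"])
token = random.choices(list(filtered), weights=list(filtered.values()))[0]
print(sorted(filtered), "->", token)
```

In this toy case the two least likely tokens are dropped by the top-p cut, while top-k has no effect because the vocabulary is smaller than 64.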
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
|
||||
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
|
||||
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
|
||||
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
|
||||
`<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
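If you work with the raw decoded text yourself, the structure described above can be split with a small regular expression. This is a minimal sketch that assumes the output contains at most one thought block in exactly the form shown; the function name is illustrative, and any other special tokens in the raw text (for example turn delimiters) are not handled.

```python
import re

# Matches the thought block format described above.
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def split_thought(raw: str) -> tuple[str, str]:
    """Return (thought, answer) from raw model output.

    If no thought block is present, the thought is an empty string.
    """
    match = THOUGHT_RE.search(raw)
    if match is None:
        return "", raw.strip()
    thought = match.group(1).strip()
    answer = raw[match.end():].strip()
    return thought, answer


# Thinking enabled: reasoning followed by the final answer.
print(split_thought("<|channel>thought\nAdd 2 and 2.<channel|>4"))
# Thinking disabled on larger models: an empty thought block.
print(split_thought("<|channel>thought\n<channel|>Hello!"))
```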
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
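One way to follow this rule when you manage the conversation history yourself is to strip thought blocks from earlier assistant turns before appending the next user message. The sketch below assumes assistant content is stored as plain strings that may still contain the thought block format shown in the previous section; the helper name is hypothetical.

```python
import re

# Thought block format used by the model when thinking is enabled.
THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(messages: list[dict]) -> list[dict]:
    """Return a copy of the history with thoughts removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        if message["role"] == "assistant" and isinstance(content, str):
            content = THOUGHT_BLOCK.sub("", content).strip()
            message = {**message, "content": content}
        cleaned.append(message)
    return cleaned


history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<|channel>thought\nSimple sum.<channel|>4"},
]
next_turn = strip_thoughts(history) + [{"role": "user", "content": "And times 3?"}]
print(next_turn[1]["content"])
```

The original list is left untouched, so the full output, including thoughts, can still be logged separately.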
|
||||
|
||||
### 4. Modality order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
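For prompts that pair one question with some media, the ordering can be enforced with a small helper. This is a sketch that assumes the content-part format used in the snippets above (`{"type": ...}` dictionaries); the function name is illustrative, and it should not be applied to prompts where interleaving text and images is intentional.

```python
MEDIA_TYPES = {"image", "audio", "video"}


def media_first(content: list[dict]) -> list[dict]:
    """Move media parts ahead of text parts, keeping relative order within each group."""
    media = [part for part in content if part.get("type") in MEDIA_TYPES]
    rest = [part for part in content if part.get("type") not in MEDIA_TYPES]
    return media + rest


content = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
]
message = {"role": "user", "content": media_first(content)}
print([part["type"] for part in message["content"]])
```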
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
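The trade-off can be estimated before running anything. The sketch below snaps a requested budget to one of the supported values and estimates the visual tokens a batch of images or video frames would consume. It treats the budget as the per-image token count, which is an approximation; how the budget is passed to the processor is not shown here, and both helper names are hypothetical.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def snap_budget(requested: int) -> int:
    """Largest supported budget not exceeding `requested` (smallest one if below all)."""
    eligible = [b for b in SUPPORTED_BUDGETS if b <= requested]
    return max(eligible) if eligible else SUPPORTED_BUDGETS[0]


def visual_tokens(num_images: int, budget: int) -> int:
    """Approximate number of tokens spent on images at a given budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported budget: {budget}")
    return num_images * budget


# 60 video frames at the lowest budget versus a single document page at the highest.
print(visual_tokens(60, snap_budget(100)))
print(visual_tokens(1, snap_budget(2000)))
```

This makes it easy to see why low budgets suit video: 60 frames at the highest budget would use 67,200 tokens of context, against 4,200 at the lowest.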
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**

```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.

Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
|
||||
|
||||
* **Automatic Speech Translation (AST)**
|
||||
|
||||
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
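The two prompt structures above can be filled in programmatically. This is a minimal sketch: the template text is copied from the prompts above, the message shape follows the audio snippet earlier in this card, and the function names and the example URL are placeholders.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', "
    "then the translation in {target}."
)


def asr_prompt(language: str) -> str:
    """Speech recognition prompt for one language."""
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source: str, target: str) -> str:
    """Speech translation prompt from `source` to `target`."""
    return AST_TEMPLATE.format(source=source, target=target)


def audio_message(audio_url: str, prompt: str) -> dict:
    """User message with the audio part placed before the text part."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }


# Placeholder URL; substitute a real audio file.
message = audio_message("https://example.com/clip.wav", ast_prompt("Spanish", "English"))
print(message["content"][1]["text"])
```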
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process videos as frames, whereas only the E2B and E4B models also support audio inputs. Audio supports a maximum length of 30 seconds. Video supports a maximum of 60 seconds, assuming the frames are processed at one frame per second.
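A simple pre-flight check can keep inputs within these limits. The sketch below computes which frame timestamps to sample from a video and splits longer audio into windows of at most 30 seconds. Truncating long video and chunking long audio are assumptions about how you might handle oversized inputs, not behavior described by this card, and the helper names are hypothetical.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_SECONDS = 60


def frame_timestamps(duration_s: float, fps: float = 1.0) -> list[float]:
    """Timestamps (in seconds) of frames to sample, capped at the video limit."""
    usable = min(duration_s, MAX_VIDEO_SECONDS)
    count = int(usable * fps)
    return [i / fps for i in range(count)]


def audio_chunks(duration_s: float) -> list[tuple[float, float]]:
    """Split an audio clip into consecutive (start, end) windows of at most 30 s."""
    chunks = []
    start = 0.0
    while start < duration_s:
        end = min(start + MAX_AUDIO_SECONDS, duration_s)
        chunks.append((start, end))
        start = end
    return chunks


print(len(frame_timestamps(90.0)))  # frames kept from a 90-second video
print(audio_chunks(70.0))           # windows for a 70-second recording
```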
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models are significantly safer than Gemma 3 and 3n models while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
  * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
  * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
  * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
  * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
  * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
  * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
  * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
  * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
  * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
  * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
  * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
  * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
  * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
  * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
  * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
  * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
  * VLMs can be misused to generate text that is false, misleading, or harmful.
  * Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
  * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
  * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, and compared to similarly sized models, this family of models provides high-performance open vision-language model implementations designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,513 @@
|
||||
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: image-text-to-text
---
|
||||
|
||||
<div align="center">
<img src="https://ai.google.dev/gemma/images/gemma4_banner.png">
</div>

<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities** – Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
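To make the local/global distinction concrete, the sketch below builds the two kinds of causal attention mask in plain Python. It is illustrative only: the window of 4 is a toy value (the tables in this card list 512 or 1024 tokens), and the convention that the window counts the current token is an assumption rather than something this card specifies.

```python
def causal_mask(seq_len, window=None):
    """Return mask[i][j] == True when query position i may attend to key position j.

    window=None gives full (global) causal attention; an integer gives
    sliding-window attention over the most recent `window` positions.
    Assumption: the window counts the current token itself.
    """
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            visible = j <= i  # causal: never look at future tokens
            if window is not None:
                visible = visible and (i - j) < window
            row.append(visible)
        mask.append(row)
    return mask


global_mask = causal_mask(6)
local_mask = causal_mask(6, window=4)

# The last query position sees every earlier token under global attention,
# but only the 4 most recent positions under the toy sliding window.
print(sum(global_mask[5]))  # 6
print(sum(local_mask[5]))   # 4
```

A local layer therefore only needs cached keys and values for the most recent `window` positions, which is where the memory saving on long contexts comes from.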
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
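As a quick check on the Dense Models table, the gap between the two figures in each cell is the parameter count that only appears in the "with embeddings" total. The snippet simply subtracts the published numbers; how that remainder splits between PLE tables and the regular token embedding is not stated in this card.

```python
# Parameter counts in billions, copied from the Dense Models table.
SIZES = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_params(name):
    """Billions of parameters counted only in the 'with embeddings' total."""
    entry = SIZES[name]
    return round(entry["with_embeddings"] - entry["effective"], 1)


for name in SIZES:
    print(name, embedding_params(name))
```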
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only about 4B parameters (3.8B) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
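The ratio behind that claim can be worked out directly from the MoE table. This is plain arithmetic on the published figures, not a measurement of inference speed.

```python
# Figures copied from the Mixture-of-Experts table.
TOTAL_PARAMS_B = 25.2
ACTIVE_PARAMS_B = 3.8
ROUTED_EXPERTS_ACTIVE = 8
ROUTED_EXPERTS_TOTAL = 128

active_param_share = ACTIVE_PARAMS_B / TOTAL_PARAMS_B
active_expert_share = ROUTED_EXPERTS_ACTIVE / ROUTED_EXPERTS_TOTAL

print(f"{active_param_share:.1%} of parameters are active during inference")
print(f"{active_expert_share:.2%} of routed experts are active")
```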
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
|
||||
|
||||
| Benchmark | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding** – Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
* **Coding** – Code generation, completion, and correction.
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add image before text
messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
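With Transformers, these values would typically be passed to `generate` as keyword arguments. The sketch below only collects them in a dictionary. `do_sample=True` is an addition not stated in this card; it is included on the assumption that sampling must be switched on for the other three values to take effect.

```python
# Recommended sampling values from this card.
SAMPLING_KWARGS = {
    "do_sample": True,   # assumption: needed so the values below take effect
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}

# Usage with the objects from Getting Started (not executed here):
# outputs = model.generate(**inputs, max_new_tokens=1024, **SAMPLING_KWARGS)
print(SAMPLING_KWARGS)
```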
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
  `<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
  `<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
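For cases where the raw decoded text has to be handled by hand, the thought block can be separated from the final answer with a regular expression. The sketch below is a hypothetical helper, `split_thought`, written against the token structure described above; it is not the library's `parse_response` and it does not strip any other special tokens that may follow the answer.

```python
import re

# Control tokens as written in this card.
THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def split_thought(raw):
    """Return (thought, answer) from a raw decoded response.

    Hypothetical helper for illustration only.
    """
    match = THOUGHT_BLOCK.search(raw)
    if match is None:
        # No thought block at all (e.g. thinking disabled on E2B/E4B).
        return "", raw.strip()
    thought = match.group(1).strip()
    answer = raw[match.end():].strip()
    return thought, answer


print(split_thought("<|channel>thought\nUser wants a joke.<channel|>Here is one."))
print(split_thought("<|channel>thought\n<channel|>Plain answer."))
```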
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
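One way to follow this rule when the history is kept as raw strings is to strip thought blocks from earlier assistant turns before appending the next user message. This is a minimal sketch, assuming each message is a dict with string `role` and `content` fields; `strip_thoughts` is a hypothetical helper, not a library function.

```python
import re

THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(messages):
    """Return a copy of the history with thought blocks removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        if message["role"] == "assistant":
            content = THOUGHT_BLOCK.sub("", content).strip()
        cleaned.append({**message, "content": content})
    return cleaned


history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<|channel>thought\nSimple sum.<channel|>4"},
]
history = strip_thoughts(history)
history.append({"role": "user", "content": "And doubled?"})
print(history)
```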
|
||||
|
||||
### 4. Modality order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
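For simple single-question prompts, the ordering can be enforced when the message is built. The helper below is a hypothetical sketch that assumes each content part is a dict with a `type` key, as in the Getting Started examples; it is not appropriate for prompts that deliberately interleave text and images.

```python
def media_first(content):
    """Reorder message parts so non-text parts precede text parts.

    Relative order within each group is preserved.
    """
    media = [part for part in content if part["type"] != "text"]
    text = [part for part in content if part["type"] == "text"]
    return media + text


parts = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "https://example.com/photo.png"},  # placeholder URL
]
print([part["type"] for part in media_first(parts)])  # ['image', 'text']
```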
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
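The guidance above can be captured in a small lookup that also guards against unsupported values. The task names and the specific budget chosen for each are illustrative assumptions, as is the default of 280. How the budget is passed to the processor is not covered in this section, so the sketch stops at choosing a value.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)

# Hypothetical task-to-budget mapping; the specific picks are illustrative.
TASK_BUDGETS = {
    "classification": 70,
    "captioning": 140,
    "video": 70,
    "ocr": 1120,
    "document_parsing": 1120,
}


def pick_budget(task, default=280):
    """Return a supported visual token budget for a task name."""
    budget = TASK_BUDGETS.get(task, default)
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported visual token budget: {budget}")
    return budget


print(pick_budget("ocr"))      # 1120
print(pick_budget("video"))    # 70
print(pick_budget("unknown"))  # 280
```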
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**

```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.

Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```

* **Automatic Speech Translation (AST)**

```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
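The two templates can be filled in programmatically. The sketch below uses Python's `str.format` with lowercase placeholder names of its own choosing; the function names `asr_prompt` and `ast_prompt` are hypothetical.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it "
    "into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', then the "
    "translation in {target}."
)


def asr_prompt(language):
    """Build the speech recognition prompt for one language."""
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source, target):
    """Build the speech translation prompt for a language pair."""
    return AST_TEMPLATE.format(source=source, target=target)


print(ast_prompt("Spanish", "English"))
```

The resulting string goes in the `text` part of the user message, after the audio part.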
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process videos as sequences of frames, whereas only the E2B and E4B models also support audio inputs. Audio inputs can be at most 30 seconds long. Video inputs can be at most 60 seconds long, assuming frames are sampled at one frame per second.
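These limits are easy to check before a request is sent. The sketch below is a hypothetical pre-flight helper: truncating an over-long video to the first 60 seconds is a choice made for illustration, and callers may prefer to reject or subsample it instead.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_SECONDS = 60
VIDEO_FPS = 1  # sampling rate assumed by the 60-second figure


def check_audio(duration_s):
    """True when an audio clip fits the documented 30-second limit."""
    return duration_s <= MAX_AUDIO_SECONDS


def video_frame_times(duration_s):
    """Timestamps (seconds) of frames sampled at 1 fps, capped at the limit."""
    usable = min(int(duration_s), MAX_VIDEO_SECONDS)
    return [t / VIDEO_FPS for t in range(usable * VIDEO_FPS)]


print(check_audio(42.0))           # False
print(len(video_frame_times(90)))  # 60
print(video_frame_times(3))        # [0.0, 1.0, 2.0]
```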
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data encompassing a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters in order to evaluate the models' own capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
  * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
  * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
  * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
  * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
  * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
  * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
  * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
  * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
  * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
  * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
  * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
  * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
  * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
  * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
  * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
  * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
  * VLMs can be misused to generate text that is false, misleading, or harmful.
  * Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
  * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
  * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, and compared to similarly sized models, this family provides high-performance open vision-language model implementations designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,513 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
|
||||
pipeline_tag: image-text-to-text
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<img src="https://ai.google.dev/gemma/images/gemma4_banner.png" alt="Gemma 4 banner">
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
|
||||
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
|
||||
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
|
||||
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
|
||||
<br>
|
||||
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
|
||||
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
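As a rough illustration of the interleaving described above, the sketch below builds a per-layer schedule of attention types. The function name and the `global_every` ratio are assumptions chosen for illustration. The model card states only that local and global layers are interleaved and that the final layer is always global.

```python
from typing import List


def attention_schedule(num_layers: int, global_every: int = 6) -> List[str]:
    """Return one 'local' or 'global' label per decoder layer.

    `global_every` is a hypothetical ratio, not a documented Gemma 4 value.
    """
    if num_layers < 1 or global_every < 1:
        raise ValueError("num_layers and global_every must be positive")
    schedule = [
        "global" if (i + 1) % global_every == 0 else "local"
        for i in range(num_layers)
    ]
    # The only property taken from the model card: the last layer is global.
    schedule[-1] = "global"
    return schedule


layers = attention_schedule(num_layers=35, global_every=6)
print(layers.count("global"), "global layers out of", len(layers))
```

Because most layers attend only within a sliding window, the cost of long contexts is concentrated in the comparatively few global layers.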
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
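To make the gap between effective and total parameters concrete, the sketch below subtracts the two figures given in the Dense Models table. Attributing the whole difference to embedding tables follows the table's "with embeddings" wording. The helper name is made up for this example.

```python
# Parameter counts in billions, taken from the Dense Models table.
MODELS = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_params(effective: float, with_embeddings: float) -> float:
    """Parameters (in billions) that are not counted as 'effective'."""
    return round(with_embeddings - effective, 1)


for name, counts in MODELS.items():
    extra = embedding_params(counts["effective"], counts["with_embeddings"])
    share = extra / counts["with_embeddings"]
    print(f"{name}: {extra}B in embedding tables ({share:.0%} of the total)")
```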
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only about 4B parameters (3.8B) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
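The arithmetic behind that claim can be checked directly from the table. This is a back-of-the-envelope sketch; reading "128 total" as the number of routed experts is an assumption.

```python
TOTAL_PARAMS_B = 25.2   # total parameters, from the table
ACTIVE_PARAMS_B = 3.8   # active parameters per token, from the table
ROUTED_EXPERTS = 128    # assumption: "128 total" counts routed experts
ACTIVE_EXPERTS = 8      # experts selected per token


def fraction(part: float, whole: float) -> float:
    """Share of `whole` that `part` represents."""
    return part / whole


param_fraction = fraction(ACTIVE_PARAMS_B, TOTAL_PARAMS_B)
expert_fraction = fraction(ACTIVE_EXPERTS, ROUTED_EXPERTS)
print(f"active parameters: {param_fraction:.1%}")
print(f"active routed experts: {expert_fraction:.2%}")
```

The active-parameter share comes out higher than the active-expert share. This is presumably because components outside the routed experts, such as attention layers and the shared expert, run for every token.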
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated on a large collection of datasets and metrics covering text, vision, audio, and long-context tasks. Results in the table are for the instruction-tuned models.
|
||||
|
||||
| Benchmark | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
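
One way to keep these values consistent across calls is to collect them in a single place. The sketch below is a plain-Python helper whose names are illustrative. `do_sample=True` is an assumption added here, since Transformers applies `temperature`, `top_p`, and `top_k` only when sampling is enabled.

```python
# Recommended sampling values from the list above.
SAMPLING_KWARGS = {
    "do_sample": True,   # assumption: needed for the values below to take effect
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}


def generation_kwargs(max_new_tokens: int = 1024, **overrides) -> dict:
    """Merge the recommended sampling values with per-call settings."""
    kwargs = {**SAMPLING_KWARGS, "max_new_tokens": max_new_tokens}
    kwargs.update(overrides)
    return kwargs


print(generation_kwargs(max_new_tokens=512))
# Usage with an already loaded model (not executed here):
# outputs = model.generate(**inputs, **generation_kwargs())
```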
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:

* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
  `<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
  `<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
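
If you do work with the raw text yourself, the control tokens above can be handled with ordinary string operations. The following is a minimal sketch, not the library's own parser. The helper names are hypothetical, and it assumes `<|think|>` is placed directly in front of the system text with no separator.

```python
import re
from typing import Tuple

THINK_TOKEN = "<|think|>"
# Matches one thought block, including an empty one.
THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def system_prompt(text: str, thinking: bool) -> str:
    """Prefix the system prompt with the thinking token when requested."""
    return f"{THINK_TOKEN}{text}" if thinking else text


def split_thought(raw: str) -> Tuple[str, str]:
    """Return (thought, answer) from raw decoded model output."""
    match = THOUGHT_BLOCK.search(raw)
    if match is None:
        return "", raw.strip()
    answer = raw[match.end():].strip()
    return match.group(1).strip(), answer


print(system_prompt("You are a helpful assistant.", thinking=True))
print(split_thought("<|channel>thought\nAdd the numbers.<channel|>4"))
```

Real decoded output may contain further special tokens that this sketch does not strip.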
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
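
In practice this means removing the thought block before a model turn is appended to the message list. The helper below is a hedged sketch with an illustrative name. It assumes the raw output uses the `<|channel>thought\n ... <channel|>` structure shown in the thinking-mode section.

```python
import re
from typing import Dict, List

THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def append_assistant_turn(history: List[Dict[str, str]], raw_output: str) -> None:
    """Store only the final answer of a model turn in the history."""
    final_answer = THOUGHT_BLOCK.sub("", raw_output).strip()
    history.append({"role": "assistant", "content": final_answer})


history = [{"role": "user", "content": "What is 2 + 2?"}]
append_assistant_turn(history, "<|channel>thought\nSimple sum.<channel|>4")
history.append({"role": "user", "content": "And times 3?"})
print(history)
```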
|
||||
|
||||
### 4. Modality Order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
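
For single-question prompts, the ordering can be enforced before the chat template is applied. The sketch below uses a hypothetical helper and a placeholder URL. It moves every non-text part ahead of the text parts and keeps the relative order within each group.

```python
from typing import Dict, List


def media_first(content: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return content parts with non-text parts ahead of text parts."""
    media = [part for part in content if part.get("type") != "text"]
    text = [part for part in content if part.get("type") == "text"]
    return media + text


parts = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "https://example.com/picture.png"},  # placeholder URL
]
print(media_first(parts))
```

Prompts that deliberately interleave text and images should not be reordered this way, because the position of each image carries meaning there.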
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
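
The trade-off can be made concrete with a small calculation. In this sketch the function names are illustrative, and it assumes each image consumes its full budget. How the budget is passed to the processor is not covered here.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def pick_budget(desired_tokens: int) -> int:
    """Return the smallest supported budget that is >= desired_tokens.

    Falls back to the largest budget when the request exceeds it.
    """
    for budget in SUPPORTED_BUDGETS:
        if budget >= desired_tokens:
            return budget
    return SUPPORTED_BUDGETS[-1]


def visual_tokens(num_images: int, budget: int) -> int:
    """Approximate visual token cost, assuming each image uses the full budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported budget: {budget}")
    return num_images * budget


print(pick_budget(200))
print(visual_tokens(60, 70), "vs", visual_tokens(60, 1120))
```

Under these assumptions, 60 video frames cost 4,200 visual tokens at the lowest budget and 67,200 at the highest. This is why lower budgets suit video.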
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**

```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.

Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```

* **Automatic Speech Translation (AST)**

```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
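The placeholders in these templates can be filled with ordinary string formatting. The sketch below copies the template wording and uses shorter placeholder names; the function names are illustrative.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n"
    "\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it "
    "into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', then the "
    "translation in {target}."
)


def asr_prompt(language: str) -> str:
    """Fill the ASR template for one language."""
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source: str, target: str) -> str:
    """Fill the AST template for a source/target language pair."""
    return AST_TEMPLATE.format(source=source, target=target)


print(ast_prompt("Spanish", "English"))
```

The resulting string goes into the `text` part of the user message, after the audio part.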
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio inputs can be at most 30 seconds long. Video inputs can be at most 60 seconds long, assuming frames are sampled at one frame per second.
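These limits can be checked before any media is sent to the model. The sketch below is one possible pre-check with illustrative helper names. Keeping only the first 60 frames of a longer video is an assumption; sampling frames more sparsely would be an equally valid choice.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_FRAMES = 60  # 60 seconds at one frame per second


def audio_fits(duration_s: float) -> bool:
    """True when an audio clip is within the 30-second limit."""
    return 0 < duration_s <= MAX_AUDIO_SECONDS


def frame_timestamps(duration_s: float, fps: float = 1.0) -> list:
    """Timestamps (in seconds) of frames to sample, capped at MAX_VIDEO_FRAMES."""
    if duration_s <= 0 or fps <= 0:
        return []
    count = min(int(duration_s * fps), MAX_VIDEO_FRAMES)
    return [i / fps for i in range(count)]


print(audio_fits(42.0))
print(len(frame_timestamps(90.0)))
```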
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning, symbolic representation, and to address mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters in order to evaluate the models' own capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
  * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
  * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
  * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
  * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
  * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
  * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
  * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
  * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
  * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
  * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
  * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
  * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
  * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
  * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
  * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
  * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
  * VLMs can be misused to generate text that is false, misleading, or harmful.
  * Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
  * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
  * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, this family of models provides open vision-language model implementations that deliver high performance compared to similarly sized models and are designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,513 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
|
||||
pipeline_tag: image-text-to-text
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
|
||||
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
|
||||
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
|
||||
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
|
||||
<br>
|
||||
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
|
||||
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
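
A minimal sketch of such an interleaved layer schedule. The card does not state the ratio of local to global layers, so `local_per_global=5` and the helper name `layer_schedule` are hypothetical; only the rule that the final layer is always global comes from the paragraph above.

```python
def layer_schedule(num_layers: int, local_per_global: int = 5) -> list[str]:
    """Label each decoder layer as "local" (sliding window) or "global".

    ASSUMPTION: a fixed repeating pattern of `local_per_global` local layers
    followed by one global layer. The real ratio is not given in the card.
    """
    if num_layers < 1 or local_per_global < 0:
        raise ValueError("num_layers must be >= 1 and local_per_global >= 0")
    period = local_per_global + 1
    schedule = [
        "global" if (index + 1) % period == 0 else "local"
        for index in range(num_layers)
    ]
    # Stated in the card: the final layer is always global.
    schedule[-1] = "global"
    return schedule


if __name__ == "__main__":
    # 35 layers matches the E2B row of the Dense Models table.
    print(layer_schedule(35))
```

Forcing the last entry to `"global"` keeps the stated invariant even when the layer count is not a multiple of the assumed period.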
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
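
As a rough illustration of the gap between total and effective parameters, the snippet below subtracts the two figures given in the Dense Models table. Attributing the whole difference to embedding tables follows the card's "with embeddings" wording and is otherwise an assumption; the helper name `embedding_billions` is hypothetical.

```python
# Parameter counts in billions, copied from the Dense Models table.
SIZES = {
    "E2B": {"effective": 2.3, "total": 5.1},
    "E4B": {"effective": 4.5, "total": 8.0},
}


def embedding_billions(name: str) -> float:
    """Total minus effective parameters, in billions (assumed to be embeddings)."""
    size = SIZES[name]
    return round(size["total"] - size["effective"], 1)


if __name__ == "__main__":
    for name in SIZES:
        print(f"{name}: about {embedding_billions(name)}B parameters beyond the effective count")
```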
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only a subset of about 4B parameters (3.8B) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it a strong choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
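
The ratio behind that claim can be worked out from the table. This is a back-of-the-envelope sketch using only the figures listed above; the variable names are illustrative.

```python
# Figures copied from the Mixture-of-Experts table.
TOTAL_PARAMS_B = 25.2
ACTIVE_PARAMS_B = 3.8
ACTIVE_EXPERTS = 8      # routed experts used at a time
TOTAL_EXPERTS = 128     # routed experts available (the shared expert is not counted)

active_share = ACTIVE_PARAMS_B / TOTAL_PARAMS_B
routed_expert_share = ACTIVE_EXPERTS / TOTAL_EXPERTS

if __name__ == "__main__":
    print(f"Active parameters: {active_share:.1%} of the total")
    print(f"Routed experts in use: {routed_expert_share:.2%} of the pool")
```

The saving is in compute per token. All 25.2B parameters generally still have to be held in memory, which is why the total count matters for hardware sizing.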
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
|
||||
|
||||
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
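
One way to carry these values around is a plain keyword dictionary that can be unpacked into a `generate` call. This is a sketch, not part of the card: the name `SAMPLING_KWARGS` is illustrative, and `do_sample=True` is added on the assumption that sampling is not already enabled by the model's generation config.

```python
# Recommended sampling configuration from the list above.
SAMPLING_KWARGS = {
    "do_sample": True,   # ASSUMPTION: needed so the values below take effect
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}


def generation_kwargs(max_new_tokens: int = 1024) -> dict:
    """Merge the recommended sampling values with a token limit."""
    return {**SAMPLING_KWARGS, "max_new_tokens": max_new_tokens}


if __name__ == "__main__":
    print(generation_kwargs())
```

With a loaded model and prepared inputs, the dictionary would be passed as `model.generate(**inputs, **generation_kwargs())`.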
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
|
||||
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
|
||||
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
|
||||
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
|
||||
`<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
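
A minimal sketch of that rule applied to a message list before the next turn is rendered. It assumes assistant turns are stored as plain strings that may still contain the raw thought block; the helper name `drop_thoughts` is hypothetical.

```python
import re

# Thought block as described in the thinking mode section of this card.
_THOUGHT_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def drop_thoughts(messages: list[dict]) -> list[dict]:
    """Return a copy of the history with thought blocks removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        # Only string content of assistant turns is rewritten; other turns pass through.
        if message["role"] == "assistant" and isinstance(content, str):
            content = _THOUGHT_RE.sub("", content).strip()
        cleaned.append({**message, "content": content})
    return cleaned


if __name__ == "__main__":
    history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "<|channel>thought\nGreet back.<channel|>Hello!"},
    ]
    print(drop_thoughts(history))
```

The function builds new dictionaries instead of editing in place, so the original history, thoughts included, remains available for logging.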
|
||||
|
||||
### 4. Modality order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
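
The trade-off can be made concrete with a small helper that snaps a requested budget to a supported value and estimates the visual token cost of a frame sequence. Both function names are hypothetical, the per-frame cost is treated as equal to the budget (an upper-bound assumption), and how the budget is handed to the processor is not covered here because the card does not specify it.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def pick_budget(requested: int) -> int:
    """Smallest supported budget that is >= requested, else the largest one."""
    for budget in SUPPORTED_BUDGETS:
        if budget >= requested:
            return budget
    return SUPPORTED_BUDGETS[-1]


def visual_token_cost(num_images: int, budget: int) -> int:
    """Upper-bound estimate: every image uses the full budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported budget: {budget}")
    return num_images * budget


if __name__ == "__main__":
    # 60 video frames at the lowest budget versus one document page at the highest.
    print(visual_token_cost(60, pick_budget(70)))
    print(visual_token_cost(1, pick_budget(1000)))
```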
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**
|
||||
|
||||
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.

Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
|
||||
|
||||
* **Automatic Speech Translation (AST)**
|
||||
|
||||
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
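
The two templates can be filled in programmatically. The sketch below reproduces their wording and wraps the result in the message shape used by the audio snippet earlier in this card, with audio placed before text; the helper names and lowercase placeholder names are illustrative.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', "
    "then the translation in {target}."
)


def asr_prompt(language: str) -> str:
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source: str, target: str) -> str:
    return AST_TEMPLATE.format(source=source, target=target)


def audio_message(audio: str, prompt: str) -> dict:
    """User turn with the audio part first, then the text instruction."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio},
            {"type": "text", "text": prompt},
        ],
    }


if __name__ == "__main__":
    print(audio_message("clip.wav", ast_prompt("English", "French")))
```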
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are sampled at one frame per second.
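
These limits can be checked before a request is built. The sketch below is an assumption-laden illustration: the helper names are hypothetical, and truncating a long video to its first 60 sampled frames is one possible policy that the card does not prescribe.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_FRAMES = 60  # 60 seconds at one frame per second


def audio_fits(duration_s: float) -> bool:
    """True when an audio clip is within the stated 30 second limit."""
    return 0 < duration_s <= MAX_AUDIO_SECONDS


def frame_timestamps(duration_s: float, fps: float = 1.0) -> list[float]:
    """Timestamps (in seconds) of the frames to sample, capped at 60 frames.

    ASSUMPTION: frames beyond the cap are dropped from the end of the video.
    """
    if duration_s < 0 or fps <= 0:
        raise ValueError("duration_s must be >= 0 and fps must be > 0")
    count = min(int(duration_s * fps), MAX_VIDEO_FRAMES)
    return [index / fps for index in range(count)]


if __name__ == "__main__":
    print(audio_fits(12.5))
    print(len(frame_timestamps(90)))
```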
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning, symbolic representation, and to address mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models show significantly improved safety over Gemma 3 and 3n models while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over the performance of previous Gemma models.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
|
||||
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
|
||||
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
|
||||
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
|
||||
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
|
||||
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
|
||||
* **Research and Education**
|
||||
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
|
||||
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
|
||||
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
|
||||
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
|
||||
* The scope of the training dataset determines the subject areas the model can handle effectively.
|
||||
* **Context and Task Complexity**
|
||||
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
|
||||
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
|
||||
* **Language Ambiguity and Nuance**
|
||||
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
|
||||
* **Factual Accuracy**
|
||||
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
|
||||
* **Common Sense**
|
||||
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
|
||||
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
|
||||
* **Misinformation and Misuse**
|
||||
* VLMs can be misused to generate text that is false, misleading, or harmful.
|
||||
* Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
|
||||
* **Transparency and Accountability**
|
||||
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
|
||||
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Developers are encouraged to monitor continuously (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, this family of models provides open vision-language model implementations that deliver high performance compared to similarly sized models and are designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,347 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
}
|
||||
{%- elif value is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- if value['required'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if response is mapping -%}
|
||||
{{- 'response:' + tool_name + '{' -}}
|
||||
{%- for key, value in response | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||
{%- for i in range(loop_messages | length) -%}
|
||||
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||
{%- set ns_turn.last_user_idx = i -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||
{%- if loop.index0 > 0 -%}
|
||||
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||
{%- if not prev_nt.found -%}
|
||||
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||
{%- set prev_nt.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||
{%- if not continue_same_model_turn -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||
{%- if message.get('tool_responses') -%}
|
||||
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endfor -%}
|
||||
{%- elif message.get('tool_calls') -%}
|
||||
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||
{%- if ns_tool_scan.stopped -%}
|
||||
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||
{%- set ns_tool_scan.stopped = true -%}
|
||||
{%- else -%}
|
||||
{%- set follow = loop_messages[k] -%}
|
||||
{#- Resolve tool_call_id to function name -#}
|
||||
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Handle content as string or content-parts array -#}
|
||||
{%- set tool_body = follow.get('content') -%}
|
||||
{%- if tool_body is string -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||
{%- set ns_txt = namespace(s='') -%}
|
||||
{%- for part in tool_body -%}
|
||||
{%- if part.get('type') == 'text' -%}
|
||||
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||
{%- else -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- endif -%}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- if not enable_thinking | default(false) -%}
|
||||
{{- '<|channel>thought\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -0,0 +1,74 @@
|
||||
{
|
||||
"audio_token": "<|audio|>",
|
||||
"backend": "tokenizers",
|
||||
"boa_token": "<|audio>",
|
||||
"boi_token": "<|image>",
|
||||
"bos_token": "<bos>",
|
||||
"eoa_token": "<audio|>",
|
||||
"eoc_token": "<channel|>",
|
||||
"eoi_token": "<image|>",
|
||||
"eos_token": "<eos>",
|
||||
"eot_token": "<turn|>",
|
||||
"escape_token": "<|\"|>",
|
||||
"etc_token": "<tool_call|>",
|
||||
"etd_token": "<tool|>",
|
||||
"etr_token": "<tool_response|>",
|
||||
"extra_special_tokens": [
|
||||
"<|video|>"
|
||||
],
|
||||
"image_token": "<|image|>",
|
||||
"mask_token": "<mask>",
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": "<pad>",
|
||||
"padding_side": "left",
|
||||
"processor_class": "Gemma4Processor",
|
||||
"response_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {
|
||||
"const": "assistant"
|
||||
},
|
||||
"thinking": {
|
||||
"type": "string"
|
||||
},
|
||||
"content": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": "<\\|tool_call>(.*?)<tool_call\\|>",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"const": "function"
|
||||
},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"x-regex": "call\\:(?P<name>\\w+)(?P<arguments>\\{.*\\})",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-parser": "gemma4-tool-call",
|
||||
"additionalProperties": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"x-regex": "(\\<\\|channel\\>thought\\n(?P<thinking>.*?)\\<channel\\|\\>)?(?P<tool_calls>\\<\\|tool_call\\>.*\\<tool_call\\|\\>)?(?P<content>(?:(?!\\<turn\\|\\>)(?!\\<\\|tool_response\\>).)+)?(?:\\<turn\\|\\>|\\<\\|tool_response\\>)?"
|
||||
},
|
||||
"soc_token": "<|channel>",
|
||||
"sot_token": "<|turn>",
|
||||
"stc_token": "<|tool_call>",
|
||||
"std_token": "<|tool>",
|
||||
"str_token": "<|tool_response>",
|
||||
"think_token": "<|think|>",
|
||||
"tokenizer_class": "GemmaTokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
|
||||
@@ -0,0 +1,516 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
|
||||
pipeline_tag: any-to-any
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
|
||||
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
|
||||
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
|
||||
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
|
||||
<br>
|
||||
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
|
||||
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes text, images at variable aspect ratios and resolutions (all models), video, and audio (natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
|
||||
| :---- | :---- | :---- | :---- |
|
||||
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
|
||||
| **Layers** | 35 | 42 | 60 |
|
||||
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
|
||||
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
|
||||
| **Vocabulary Size** | 262K | 262K | 262K |
|
||||
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
|
||||
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
|
||||
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
|
||||
| :---- | :---- |
|
||||
| **Total Parameters** | 25.2B |
|
||||
| **Active Parameters** | 3.8B |
|
||||
| **Layers** | 30 |
|
||||
| **Sliding Window** | 1024 tokens |
|
||||
| **Context Length** | 256K tokens |
|
||||
| **Vocabulary Size** | 262K |
|
||||
| **Expert Count** | 8 active / 128 total and 1 shared |
|
||||
| **Supported Modalities** | Text, Image |
|
||||
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
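As a rough illustration of the figures in the table above, the arithmetic below works out what share of the parameters and routed experts is active per token. It uses only the numbers quoted in this card; the variable names are made up for the example, and the shared expert is left out of the expert ratio.

```python
# Figures copied from the MoE table above (in billions of parameters).
total_params_b = 25.2
active_params_b = 3.8

# Routed experts: 8 active out of 128 (the shared expert is not counted here).
active_experts = 8
total_experts = 128

param_fraction = active_params_b / total_params_b
expert_fraction = active_experts / total_experts

print(f"Active parameter share: {param_fraction:.1%}")
print(f"Active routed-expert share: {expert_fraction:.2%}")
```

Roughly one parameter in seven takes part in each forward pass, which is the basis for the speed comparison with a 4B-parameter model.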
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated on a broad set of datasets and metrics covering text, vision, audio, and long-context tasks. All results in the table are for instruction-tuned models.
|
||||
|
||||
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
|
||||
| :---- | :---- | :---- | :---- | :---- | :---- |
|
||||
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
|
||||
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
|
||||
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
|
||||
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
|
||||
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
|
||||
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
|
||||
| HLE no tools | 19.5% | 8.7% | - | - | - |
|
||||
| HLE with search | 26.5% | 17.2% | - | - | - |
|
||||
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
|
||||
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
|
||||
| **Vision** | | | | | |
|
||||
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
|
||||
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
|
||||
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
|
||||
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
|
||||
| **Audio** | | | | | |
|
||||
| CoVoST | - | - | 35.54 | 33.47 | - |
|
||||
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
|
||||
| **Long Context** | | | | | |
|
||||
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
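To make the function-calling output format more concrete, the sketch below pulls tool calls out of a raw, undecoded model response with plain regular expressions. The two patterns are adapted from the `response_schema` in the `tokenizer_config.json` shipped alongside the model; `extract_tool_calls` and the `get_weather` tool are hypothetical names used only for illustration. The argument text is returned as-is, because decoding it is the job of the tokenizer's own tool-call parser, which is not reproduced here.

```python
import re

# Patterns adapted from the `response_schema` in tokenizer_config.json.
TOOL_CALL_BLOCK = re.compile(r"<\|tool_call>(.*?)<tool_call\|>", re.DOTALL)
CALL_PARTS = re.compile(r"call:(?P<name>\w+)(?P<arguments>\{.*\})", re.DOTALL)


def extract_tool_calls(raw_output: str) -> list[dict[str, str]]:
    """Return the name and raw (undecoded) argument text of each tool call."""
    calls = []
    for block in TOOL_CALL_BLOCK.findall(raw_output):
        match = CALL_PARTS.match(block)
        if match:
            calls.append({"name": match["name"], "arguments": match["arguments"]})
    return calls


# Hypothetical model output containing a single call to a made-up tool.
sample = '<|tool_call>call:get_weather{city:<|"|>Paris<|"|>}<tool_call|>'
print(extract_tool_calls(sample))
```

This is meant for inspection and debugging only; for application code, `processor.parse_response` is the route the card itself recommends.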
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
|
||||
# Prompt
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Write a short joke about saving RAM."},
|
||||
]
|
||||
|
||||
# Process input
|
||||
text = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
inputs = processor(text=text, return_tensors="pt").to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=1024)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
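For a sense of what that parsing involves, here is a minimal sketch that separates the thought block from the final answer in a raw decoded response. It is not the implementation of `parse_response`; it only assumes the `<|channel>thought\n ... <channel|>` markers described under Best Practices, and `split_thinking` is a hypothetical helper name.

```python
import re

# An optional thought block delimited by the channel markers.
THOUGHT_BLOCK = re.compile(
    r"<\|channel>thought\n(?P<thinking>.*?)<channel\|>", re.DOTALL
)


def split_thinking(raw_output: str) -> tuple[str, str]:
    """Return (thinking, answer); thinking is '' when no thought block exists."""
    match = THOUGHT_BLOCK.search(raw_output)
    thinking = match["thinking"].strip() if match else ""
    answer = raw_output[match.end():] if match else raw_output
    # Drop the end-of-turn marker that remains when special tokens are kept.
    answer = answer.replace("<turn|>", "").strip()
    return thinking, answer


thinking, answer = split_thinking(
    "<|channel>thought\nAdd 2 and 2.\n<channel|>The result is 4.<turn|>"
)
print(thinking)
print(answer)
```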
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForMultimodalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
|
||||
# Prompt - add audio before text
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
|
||||
{"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process input
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForMultimodalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
|
||||
# Prompt - add image before text
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
|
||||
{"type": "text", "text": "What is shown in this image?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process input
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForMultimodalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
|
||||
# Prompt - add video before text
|
||||
messages = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [
|
||||
{"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
|
||||
{'type': 'text', 'text': 'Describe this video.'}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process input
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
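With Transformers, these values can be passed as keyword arguments to `model.generate`. The snippet below is a sketch under one assumption that is not stated in this card: that `do_sample=True` has to be set explicitly for the sampling values to take effect. The model's bundled generation config may already do this.

```python
# Recommended sampling settings from this card, expressed as keyword
# arguments for `model.generate` in Transformers.
SAMPLING_KWARGS = {
    "do_sample": True,  # assumption: enable sampling so the values below apply
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}

# Usage, assuming `model` and `inputs` from the Getting Started snippet:
# outputs = model.generate(**inputs, max_new_tokens=1024, **SAMPLING_KWARGS)
```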
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, these models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
|
||||
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
|
||||
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
|
||||
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
|
||||
`<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
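
When a prompt is assembled by hand rather than through a chat template, the thinking trigger can be added as in the sketch below. `build_system_prompt` is a hypothetical helper, and putting the token on its own line ahead of the instructions follows the chat template shipped with the model. Use either this or the template's `enable_thinking` flag, not both, so that the token is not emitted twice.

```python
THINK_TOKEN = "<|think|>"


def build_system_prompt(instructions: str, enable_thinking: bool) -> str:
    """Prefix the system prompt with the thinking token when requested."""
    instructions = instructions.strip()
    if not enable_thinking:
        return instructions
    # The token goes first, on its own line, ahead of any instructions.
    return f"{THINK_TOKEN}\n{instructions}" if instructions else THINK_TOKEN


print(build_system_prompt("You are a helpful assistant.", enable_thinking=True))
```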
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
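A minimal sketch of this rule, assuming the raw model output still contains the `<|channel>thought\n ... <channel|>` block described in the thinking-mode section. The helper name `append_model_turn` is hypothetical; it drops the thought block before the turn is stored in the history.

```python
import re

# Matches one thought block, including its opening and closing tags.
THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)

def append_model_turn(history, raw_output):
    """Return a new history with the model turn added, minus any thought block."""
    final_answer = THOUGHT_BLOCK.sub("", raw_output).strip()
    return history + [{"role": "assistant", "content": final_answer}]

history = [{"role": "user", "content": "What is 2 + 2?"}]
history = append_model_turn(history, "<|channel>thought\nSimple sum.<channel|>4")
history.append({"role": "user", "content": "And doubled?"})
print(history)
```

Returning a new list rather than mutating the argument keeps the raw output available to the caller, for example for logging, without risking that it leaks into the next prompt.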
|
||||
|
||||
### 4. Modality Order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**
|
||||
|
||||
```text
|
||||
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
|
||||
|
||||
Follow these specific instructions for formatting the answer:
|
||||
* Only output the transcription, with no newlines.
|
||||
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
|
||||
```
|
||||
|
||||
* **Automatic Speech Translation (AST)**
|
||||
|
||||
```text
|
||||
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
|
||||
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
|
||||
```
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio input is limited to 30 seconds. Video input is limited to 60 seconds, assuming frames are sampled at one frame per second.
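The limits above translate into simple arithmetic: 60 seconds of video at one frame per second is at most 60 frames, and longer audio has to be split into windows of at most 30 seconds. The sketch below illustrates this; the helper names and the policy of truncating video and chunking audio are assumptions for illustration, not behavior documented by the card.

```python
import math

MAX_AUDIO_SECONDS = 30.0  # limit stated in the card
MAX_VIDEO_SECONDS = 60.0  # limit stated in the card, at one frame per second

def video_frame_times(duration_s, fps=1.0):
    """Timestamps (seconds) of sampled frames, limited to the first 60 seconds."""
    usable = min(float(duration_s), MAX_VIDEO_SECONDS)
    count = math.floor(usable * fps)
    return [index / fps for index in range(count)]

def audio_windows(duration_s, limit=MAX_AUDIO_SECONDS):
    """Split an audio clip into consecutive (start, end) windows of at most `limit` seconds."""
    windows, start = [], 0.0
    total = float(duration_s)
    while start < total:
        end = min(start + limit, total)
        windows.append((start, end))
        start = end
    return windows

print(len(video_frame_times(90)))  # a 90-second clip is cut to 60 frames
print(audio_windows(75))           # three windows: 30 s, 30 s, 15 s
```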
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data encompassing a wide range of domains and modalities, which includes web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and answer mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the model produced minimal policy violations and showed significant improvements over the performance of previous Gemma models.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
|
||||
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
|
||||
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
|
||||
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
|
||||
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
|
||||
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
|
||||
* **Research and Education**
|
||||
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
|
||||
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
|
||||
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
|
||||
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
|
||||
* The scope of the training dataset determines the subject areas the model can handle effectively.
|
||||
* **Context and Task Complexity**
|
||||
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
|
||||
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
|
||||
* **Language Ambiguity and Nuance**
|
||||
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
|
||||
* **Factual Accuracy**
|
||||
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
|
||||
* **Common Sense**
|
||||
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
|
||||
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
|
||||
* **Misinformation and Misuse**
|
||||
* VLMs can be misused to generate text that is false, misleading, or harmful.
|
||||
* Guidelines for responsible use of the model are provided; see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
|
||||
* **Transparency and Accountability**
|
||||
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
|
||||
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Continuous monitoring (using evaluation metrics and human review) and the exploration of de-biasing techniques are encouraged during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, this family of models provides high-performance open vision-language model implementations that, compared to similarly sized models, are designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,513 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
|
||||
pipeline_tag: any-to-any
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<img src="https://ai.google.dev/gemma/images/gemma4_banner.png">
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
|
||||
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
|
||||
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
|
||||
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
|
||||
<br>
|
||||
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
|
||||
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
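The interleaving can be pictured with a small sketch. The card does not state how many local layers sit between global ones, so the `local_per_global` parameter below is a hypothetical knob, and the function names are illustrative. The only facts taken from the card are that local layers use a sliding window (the tables in this card list 512 or 1024 tokens), that global layers see the full context, and that the final layer is always global. The exact window convention (whether the current token counts toward the window) is also an assumption.

```python
def layer_kinds(num_layers, local_per_global):
    """Hypothetical schedule: `local_per_global` local layers, then one global layer."""
    kinds = []
    for index in range(num_layers):
        is_global = (index + 1) % (local_per_global + 1) == 0
        kinds.append("global" if is_global else "local")
    kinds[-1] = "global"  # the card states the final layer is always global
    return kinds

def can_attend(query_pos, key_pos, kind, window):
    """Causal visibility of `key_pos` from `query_pos` for one layer kind."""
    if key_pos > query_pos:
        return False  # causal: never look ahead
    # Global layers see everything before them; local layers only the last `window` tokens.
    return kind == "global" or query_pos - key_pos < window

print(layer_kinds(6, local_per_global=2))
print(can_attend(2000, 0, "local", window=1024))   # outside the sliding window
print(can_attend(2000, 0, "global", window=1024))  # visible to a global layer
```

Local layers only need to keep keys and values for the most recent `window` tokens, which is where the memory savings for long contexts come from.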
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
|
||||
| :---- | :---- | :---- | :---- |
|
||||
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
|
||||
| **Layers** | 35 | 42 | 60 |
|
||||
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
|
||||
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
|
||||
| **Vocabulary Size** | 262K | 262K | 262K |
|
||||
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
|
||||
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
|
||||
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
|
||||
| :---- | :---- |
|
||||
| **Total Parameters** | 25.2B |
|
||||
| **Active Parameters** | 3.8B |
|
||||
| **Layers** | 30 |
|
||||
| **Sliding Window** | 1024 tokens |
|
||||
| **Context Length** | 256K tokens |
|
||||
| **Vocabulary Size** | 262K |
|
||||
| **Expert Count** | 8 active / 128 total and 1 shared |
|
||||
| **Supported Modalities** | Text, Image |
|
||||
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By activating only a roughly 4B-parameter subset (3.8B) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
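As a quick worked check of the figures in the table above, the snippet below computes the share of parameters and of routed experts that are active per token. The variable names are illustrative; the numbers are taken from the table.

```python
total_params_b = 25.2   # total parameters, in billions
active_params_b = 3.8   # active parameters per token, in billions
experts_total = 128     # routed experts
experts_active = 8      # routed experts selected per token

active_fraction = active_params_b / total_params_b
routed_fraction = experts_active / experts_total

print(f"{active_fraction:.1%} of parameters are active per token")      # 15.1%
print(f"{routed_fraction:.2%} of routed experts are selected per token")  # 6.25%
```

The parameter share is larger than the expert share, presumably because components such as attention, embeddings, and the shared expert run for every token regardless of routing.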
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. The results in the table are for the instruction-tuned models.
|
||||
|
||||
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
|
||||
| :---- | :---- | :---- | :---- | :---- | :---- |
|
||||
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
|
||||
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
|
||||
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
|
||||
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
|
||||
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
|
||||
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
|
||||
| HLE no tools | 19.5% | 8.7% | - | - | - |
|
||||
| HLE with search | 26.5% | 17.2% | - | - | - |
|
||||
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
|
||||
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
|
||||
| **Vision** | | | | | |
|
||||
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
|
||||
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
|
||||
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
|
||||
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
|
||||
| **Audio** | | | | | |
|
||||
| CoVoST | - | - | 35.54 | 33.47 | - |
|
||||
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
|
||||
| **Long Context** | | | | | |
|
||||
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
|
||||
# Prompt
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Write a short joke about saving RAM."},
|
||||
]
|
||||
|
||||
# Process input
|
||||
text = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
inputs = processor(text=text, return_tensors="pt").to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=1024)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForMultimodalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
|
||||
# Prompt - add audio before text
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
|
||||
{"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process input
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForMultimodalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
|
||||
# Prompt - add image before text
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
|
||||
{"type": "text", "text": "What is shown in this image?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process input
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForMultimodalLM
|
||||
|
||||
MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
# Load model
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
|
||||
# Prompt - add video before text
|
||||
messages = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [
|
||||
{"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
|
||||
{'type': 'text', 'text': 'Describe this video.'}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process input
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Generate output
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
||||
|
||||
# Parse output
|
||||
processor.parse_response(response)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
|
||||
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
|
||||
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
|
||||
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
|
||||
`<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
|
||||
|
||||
### 4. Modality Order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
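One way to enforce this ordering is to normalize each message before it reaches the chat template. The sketch below uses the content-part dictionaries shown in the snippets in this card; the helper name `media_first` is hypothetical.

```python
def media_first(content):
    """Return content parts with non-text items ahead of text items.

    The relative order inside each group is preserved.
    """
    media = [part for part in content if part.get("type") != "text"]
    text = [part for part in content if part.get("type") == "text"]
    return media + text

content = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
]
message = {"role": "user", "content": media_first(content)}
print([part["type"] for part in message["content"]])
```

This only suits prompts where the text does not refer to media by position; deliberately interleaved prompts should be left as written.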
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
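The trade-off can be made concrete with a rough token count. The sketch below picks a supported budget and estimates an upper bound on visual tokens for a batch of frames. It assumes each image or frame uses at most its budget and ignores text and separator tokens; the helper names are hypothetical, and how the budget is actually passed to a processor is not shown here because the card does not specify it.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)

def pick_budget(requested):
    """Smallest supported budget that covers `requested`, capped at the largest."""
    for budget in SUPPORTED_BUDGETS:
        if budget >= requested:
            return budget
    return SUPPORTED_BUDGETS[-1]

def visual_token_upper_bound(num_images, budget):
    """Upper bound on visual tokens if every image uses its full budget."""
    return num_images * budget

# 60 video frames at the lowest and the highest budget.
print(visual_token_upper_bound(60, pick_budget(70)))    # 4200
print(visual_token_upper_bound(60, pick_budget(1120)))  # 67200
```

At the highest budget, a 60-frame video alone can use roughly half of a 128K context window, which is one reason lower budgets suit video.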
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**
|
||||
|
||||
```text
|
||||
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
|
||||
|
||||
Follow these specific instructions for formatting the answer:
|
||||
* Only output the transcription, with no newlines.
|
||||
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
|
||||
```
|
||||
|
||||
* **Automatic Speech Translation (AST)**
|
||||
|
||||
```text
|
||||
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
|
||||
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
|
||||
```
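The two templates can be filled in programmatically. The sketch below formats them, wraps the result in a user message with the audio part first, and splits an AST response on the `'{TARGET_LANGUAGE}: '` marker requested by the prompt. The helper names are hypothetical, and the response parsing assumes the model follows the requested format exactly.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)
AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', "
    "then the translation in {target}."
)

def audio_message(audio_url, prompt):
    """Build a user message with the audio part placed before the text part."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }

def split_ast_output(text, target):
    """Split an AST response into (transcription, translation)."""
    transcription, _, translation = text.partition(f"\n{target}: ")
    return transcription.strip(), translation.strip()

asr_prompt = ASR_TEMPLATE.format(language="English")
ast_prompt = AST_TEMPLATE.format(source="Spanish", target="English")
message = audio_message("https://example.com/clip.wav", ast_prompt)  # placeholder URL
print(split_ast_output("Hola mundo\nEnglish: Hello world", "English"))
```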
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio input is limited to 30 seconds. Video input is limited to 60 seconds, assuming frames are sampled at one frame per second.
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters in order to evaluate the models' own capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
  * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
  * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
  * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
  * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
  * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
  * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
  * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
  * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
  * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
  * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
  * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
  * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
  * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
  * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
  * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
  * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
  * VLMs can be misused to generate text that is false, misleading, or harmful.
  * Guidelines for responsible use of the model are provided in the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
  * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
  * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, this family of models provides open vision-language model implementations that are high-performance compared to similarly sized models and designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,513 @@
|
||||
---
|
||||
library_name: transformers
|
||||
license: apache-2.0
|
||||
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
|
||||
pipeline_tag: any-to-any
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
|
||||
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
|
||||
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
|
||||
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
|
||||
<br>
|
||||
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
|
||||
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
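
The interleaving described above can be pictured with a short sketch. The model card does not state the ratio of local to global layers, so the `global_every=6` period below is purely an assumption for illustration, as are the function names and the exact window convention (the window is taken to include the current token). Only the rule that the final layer is global comes from the card.

```python
def layer_types(num_layers: int, global_every: int = 6) -> list[str]:
    """Label each layer 'local' or 'global'.

    `global_every` is a hypothetical period; the card only guarantees
    that the final layer is always global.
    """
    types = [
        "global" if (i + 1) % global_every == 0 else "local"
        for i in range(num_layers)
    ]
    types[-1] = "global"  # final layer is always global (stated in the card)
    return types


def attended_positions(query_pos: int, layer_type: str, window: int) -> range:
    """Causal key positions visible to the token at `query_pos`."""
    if layer_type == "global":
        start = 0  # full attention over everything before the query
    else:
        start = max(0, query_pos - window + 1)  # sliding window
    return range(start, query_pos + 1)


if __name__ == "__main__":
    print(layer_types(35).count("global"))
    print(len(attended_positions(2000, "local", 512)))
    print(len(attended_positions(2000, "global", 512)))
```

The point of the design shows up in the last two lines: a local layer touches a fixed number of keys no matter how long the context is, while a global layer's cost grows with position.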
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
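
As a rough worked example, the gap between the two parameter counts in the table can be computed directly. This sketch assumes that the difference between the "with embeddings" figure and the "effective" figure is attributable to embedding tables; the card implies this but does not break the numbers down further, and the function name is hypothetical.

```python
# (effective, total-with-embeddings) parameter counts in billions,
# taken from the Dense Models table.
PARAMS_B = {
    "E2B": (2.3, 5.1),
    "E4B": (4.5, 8.0),
}


def embedding_share(effective_b: float, total_b: float) -> tuple[float, float]:
    """Return (embedding parameters in billions, their share of the total).

    Assumes total - effective is entirely embedding tables.
    """
    embedding_b = total_b - effective_b
    return embedding_b, embedding_b / total_b


if __name__ == "__main__":
    for name, (effective, total) in PARAMS_B.items():
        emb, share = embedding_share(effective, total)
        print(f"{name}: ~{emb:.1f}B in embeddings ({share:.0%} of total)")
```

Under that assumption, roughly half of each small model's stored parameters sit in lookup tables rather than in weights that are multiplied on every token.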
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
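
The ratio behind this claim is easy to check from the table. The sketch below only does the arithmetic; the variable and function names are hypothetical, and the remark about why the two fractions differ is a general property of MoE layouts rather than a statement from the card.

```python
TOTAL_PARAMS_B = 25.2   # total parameters, from the MoE table
ACTIVE_PARAMS_B = 3.8   # active parameters per token, from the MoE table
ACTIVE_EXPERTS = 8      # routed experts used per token
TOTAL_EXPERTS = 128     # routed experts available


def fraction(part: float, whole: float) -> float:
    """Plain ratio, kept as a function so the two uses read the same way."""
    return part / whole


if __name__ == "__main__":
    print(f"active parameter share: {fraction(ACTIVE_PARAMS_B, TOTAL_PARAMS_B):.1%}")
    print(f"routed experts used:    {fraction(ACTIVE_EXPERTS, TOTAL_EXPERTS):.1%}")
```

About 15% of the parameters are active per token, while only about 6% of the routed experts are selected. The active share is higher than the expert share because components outside the routed experts, such as attention and the shared expert, run for every token.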
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated against a large collection of datasets and metrics covering different aspects of text generation. The results in the table are for instruction-tuned models.
|
||||
|
||||
| Benchmark | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add image before text
messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
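
These values can be kept in one place and reused across calls. The sketch below is a minimal illustration: `generation_kwargs` is a hypothetical helper, and including `do_sample=True` is an assumption, since the sampling parameters only take effect in Transformers' `generate()` when sampling is enabled. A checkpoint's own generation config may already carry these defaults.

```python
# Recommended sampling configuration from the model card.
SAMPLING_KWARGS = {
    "do_sample": True,   # assumption: enable sampling so the values below apply
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}


def generation_kwargs(max_new_tokens: int = 1024, **overrides) -> dict:
    """Merge the recommended sampling settings with per-call options."""
    return {**SAMPLING_KWARGS, "max_new_tokens": max_new_tokens, **overrides}


if __name__ == "__main__":
    print(generation_kwargs(512))
```

With a loaded model and prepared inputs, this could be used as `model.generate(**inputs, **generation_kwargs())`.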
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
|
||||
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
|
||||
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
|
||||
* **Disabled Thinking Behavior:** For all models except the E2B and E4B variants, the model still generates the tags when thinking is disabled, but with an empty thought block:
|
||||
`<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
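
For cases where the raw text is handled by hand, the control tokens above can be processed with a few lines of plain Python. This is a minimal sketch, not the library's parser: the function names are hypothetical, the `<|think|>` token is prepended with no separator (whether one is expected is an assumption), and other special tokens that may appear in raw decoded output are not handled.

```python
import re

THINK_TOKEN = "<|think|>"
# Matches a thought block: <|channel>thought\n ... <channel|>
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def system_prompt(text: str, thinking: bool) -> str:
    """Prefix the system prompt with the thinking token when enabled."""
    return f"{THINK_TOKEN}{text}" if thinking else text


def split_thought(raw: str) -> tuple[str, str]:
    """Split raw model text into (thought, final_answer).

    An empty or missing thought block yields an empty thought string.
    """
    match = THOUGHT_RE.search(raw)
    if match is None:
        return "", raw.strip()
    answer = raw[:match.start()] + raw[match.end():]
    return match.group(1).strip(), answer.strip()


if __name__ == "__main__":
    print(system_prompt("You are a helpful assistant.", thinking=True))
    print(split_thought("<|channel>thought\nAdd 2 and 2.<channel|>4"))
    print(split_thought("<|channel>thought\n<channel|>Hello!"))
```

The second call covers the empty thought block that the larger models emit when thinking is disabled.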
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
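
A conversation loop can enforce this rule by stripping thought blocks before a model turn is stored. The sketch below assumes the `<|channel>thought\n…<channel|>` structure described in the thinking-mode section; the helper names are hypothetical and the message dictionaries follow the role/content shape used elsewhere in this card.

```python
import re

# Thought block as described in the thinking-mode section.
THOUGHT_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(text: str) -> str:
    """Remove every thought block, leaving only the final response."""
    return THOUGHT_RE.sub("", text).strip()


def append_turn(history: list[dict], raw_model_output: str,
                next_user_message: str) -> list[dict]:
    """Return a new history with the cleaned model turn and the next user turn."""
    return history + [
        {"role": "assistant", "content": strip_thoughts(raw_model_output)},
        {"role": "user", "content": next_user_message},
    ]


if __name__ == "__main__":
    history = [{"role": "user", "content": "What is 2 + 2?"}]
    raw = "<|channel>thought\nSimple arithmetic.<channel|>2 + 2 = 4."
    for message in append_turn(history, raw, "And 3 + 3?"):
        print(message)
```

Returning a new list rather than mutating `history` keeps the raw output available to the caller, for example for logging, without it leaking into the next prompt.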
|
||||
|
||||
### 4. Modality order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
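
If message parts are assembled dynamically, a small helper can apply this ordering. This is a sketch under assumptions: `media_first` is a hypothetical name, video parts are treated the same way as image and audio parts, and the reordering is only appropriate when the text does not depend on a specific interleaving of parts.

```python
# Content part types treated as media (including "video" is an assumption).
MEDIA_TYPES = {"image", "audio", "video"}


def media_first(content: list[dict]) -> list[dict]:
    """Return the parts with media first, preserving relative order within each group."""
    media = [part for part in content if part.get("type") in MEDIA_TYPES]
    other = [part for part in content if part.get("type") not in MEDIA_TYPES]
    return media + other


if __name__ == "__main__":
    parts = [
        {"type": "text", "text": "What is shown in this image?"},
        {"type": "image", "url": "https://example.com/picture.png"},  # placeholder URL
    ]
    print([part["type"] for part in media_first(parts)])
```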
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**
|
||||
|
||||
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.

Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
|
||||
|
||||
* **Automatic Speech Translation (AST)**
|
||||
|
||||
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are sampled at one frame per second.
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety, while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the models' capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models' performance.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
    * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
    * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
    * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
    * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
    * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
    * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
    * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
    * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
    * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
    * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
    * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
    * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
    * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
    * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
    * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
    * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
    * VLMs can be misused to generate text that is false, misleading, or harmful.
    * Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
    * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
    * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, this family of models provides high-performance open vision-language model implementations designed from the ground up for responsible AI development compared to similarly sized models.
|
||||
@@ -0,0 +1,515 @@
|
||||
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: any-to-any
---
|
||||
|
||||
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>

<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
|
||||
|
||||
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
|
||||
|
||||
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
|
||||
|
||||
Gemma 4 introduces key **capability and architectural advancements**:
|
||||
|
||||
* **Reasoning** – All models in the family are designed as highly capable reasoners, with configurable thinking modes.
|
||||
|
||||
* **Extended Multimodalities** – Processes text and images with variable aspect ratio and resolution support (all models), video as sequences of frames, and audio (natively on the E2B and E4B models).
|
||||
|
||||
* **Diverse & Efficient Architectures** – Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
|
||||
|
||||
* **Optimized for On-Device** – Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
|
||||
|
||||
* **Increased Context Window** – The small models feature a 128K context window, while the medium models support 256K.
|
||||
|
||||
* **Enhanced Coding & Agentic Capabilities** – Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
|
||||
|
||||
* **Native System Prompt Support** – Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
|
||||
|
||||
## **Models Overview**
|
||||
|
||||
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
|
||||
|
||||
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
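
The interleaving described above can be pictured with a small sketch. The card does not state the ratio of local to global layers, so `global_every` below is a hypothetical parameter chosen only for illustration, and the window convention (a token sees itself plus the previous `window - 1` tokens) is likewise an assumption. Only the "final layer is always global" rule comes from the text.

```python
def layer_attention_types(num_layers: int, global_every: int) -> list[str]:
    """Label each decoder layer as 'local' (sliding window) or 'global'.

    ASSUMPTION: every `global_every`-th layer is global. The real
    interleaving ratio is not given in this card.
    """
    if num_layers < 1 or global_every < 1:
        raise ValueError("num_layers and global_every must be positive")
    types = [
        "global" if (i + 1) % global_every == 0 else "local"
        for i in range(num_layers)
    ]
    types[-1] = "global"  # stated in the card: the final layer is always global
    return types


def can_attend(query_pos: int, key_pos: int, layer_type: str, window: int) -> bool:
    """Causal attention rule for one (query, key) pair of token positions."""
    if key_pos > query_pos:
        return False  # causal: never attend to future tokens
    if layer_type == "global":
        return True   # full attention over the whole prefix
    # ASSUMPTION: the window counts the query token itself.
    return query_pos - key_pos < window


# 35 layers and a 512-token window match the E2B column of the table below;
# global_every=5 is purely illustrative.
types = layer_attention_types(num_layers=35, global_every=5)
```

With these assumed settings, a local layer at position 600 cannot see position 10, while a global layer can, which is the trade-off between memory footprint and long-range awareness that the paragraph describes.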
|
||||
|
||||
### Dense Models
|
||||
|
||||
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
|
||||
|
||||
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
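
As a rough illustration of the gap between effective and total parameters, the sketch below subtracts the two figures from the Dense Models table. Treating the whole difference as embedding-table parameters is our reading of the paragraph above, not a breakdown published in the card.

```python
# Figures copied from the Dense Models table, in billions of parameters.
PARAMS_B = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_share(model: str) -> tuple[float, float]:
    """Return (embedding params in billions, share of the total).

    ASSUMPTION: total minus effective equals the embedding tables.
    """
    p = PARAMS_B[model]
    embedding_b = p["with_embeddings"] - p["effective"]
    return round(embedding_b, 1), round(embedding_b / p["with_embeddings"], 2)


e2b_embeddings, e2b_share = embedding_share("E2B")
e4b_embeddings, e4b_share = embedding_share("E4B")
```

Under that assumption, roughly half of the E2B parameter count sits in lookup tables rather than in the per-token compute path.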
|
||||
|
||||
### Mixture-of-Experts (MoE) Model
|
||||
|
||||
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
|
||||
|
||||
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only about 4B parameters (3.8B, per the table above) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
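
The ratios behind that claim follow directly from the table above. The snippet is plain arithmetic on the published figures; it says nothing about measured throughput.

```python
# Figures copied from the Mixture-of-Experts table.
TOTAL_PARAMS_B = 25.2    # total parameters, billions
ACTIVE_PARAMS_B = 3.8    # parameters active per token, billions
ROUTED_EXPERTS_ACTIVE = 8
ROUTED_EXPERTS_TOTAL = 128

# Fraction of all parameters that take part in a single forward pass.
active_fraction = ACTIVE_PARAMS_B / TOTAL_PARAMS_B

# Fraction of routed experts selected per token (the shared expert is extra).
routed_fraction = ROUTED_EXPERTS_ACTIVE / ROUTED_EXPERTS_TOTAL

print(f"active parameters: {active_fraction:.1%}")
print(f"routed experts used per token: {routed_fraction:.2%}")
```

Note that all 25.2B parameters still have to be held in memory; only the per-token compute scales with the active subset.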
|
||||
|
||||
## **Benchmark Results**
|
||||
|
||||
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
|
||||
|
||||
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
|
||||
|
||||
## **Core Capabilities**
|
||||
|
||||
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
|
||||
|
||||
* **Thinking** – Built-in reasoning mode that lets the model think step-by-step before answering.
|
||||
* **Long Context** – Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
|
||||
* **Image Understanding** – Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
|
||||
* **Video Understanding** – Analyze video by processing sequences of frames.
|
||||
* **Interleaved Multimodal Input** – Freely mix text and images in any order within a single prompt.
|
||||
* **Function Calling** – Native support for structured tool use, enabling agentic workflows.
|
||||
* **Coding** – Code generation, completion, and correction.
|
||||
* **Multilingual** – Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
|
||||
* **Audio** (E2B and E4B only) – Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
|
||||
|
||||
`pip install -U transformers torch accelerate`
|
||||
|
||||
Once you have everything installed, you can proceed to load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output:
|
||||
|
||||
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
|
||||
|
||||
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Audio</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Images</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
|
||||
|
||||
|
||||
`pip install -U transformers torch torchvision accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add image before text
messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Code for processing Videos</summary>
|
||||
|
||||
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
|
||||
|
||||
`pip install -U transformers torch torchvision librosa accelerate`
|
||||
|
||||
You can then load the model with the code below:
|
||||
|
||||
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
|
||||
|
||||
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
|
||||
|
||||
|
||||
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## **Best Practices**
|
||||
|
||||
For the best performance, use these configurations and best practices:
|
||||
|
||||
### 1. Sampling Parameters
|
||||
|
||||
Use the following standardized sampling configuration across all use cases:
|
||||
|
||||
* `temperature=1.0`
|
||||
* `top_p=0.95`
|
||||
* `top_k=64`
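
One way to carry these values into the Transformers snippets above is a small keyword-argument helper. This is a sketch, not part of the card: `generation_kwargs` is a hypothetical helper name, and `do_sample=True` is added on the assumption that sampling must be switched on for `temperature`, `top_p`, and `top_k` to take effect in `model.generate`.

```python
# Recommended sampling values from the list above.
SAMPLING_KWARGS = {
    "do_sample": True,   # ASSUMPTION: needed so the values below are used
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}


def generation_kwargs(max_new_tokens: int = 512, **overrides) -> dict:
    """Merge the recommended sampling values with per-call overrides."""
    kwargs = {"max_new_tokens": max_new_tokens, **SAMPLING_KWARGS}
    kwargs.update(overrides)
    return kwargs


# Intended use with the objects from the Getting Started section:
# outputs = model.generate(**inputs, **generation_kwargs(max_new_tokens=1024))
```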
|
||||
|
||||
### 2. Thinking Mode Configuration
|
||||
|
||||
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
|
||||
|
||||
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model outputs its internal reasoning, followed by the final answer, using this structure:
    `<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except the E2B and E4B variants, if thinking is disabled, the model still generates the tags, but with an empty thought block:
    `<|channel>thought\n<channel|>`**[Final answer]**
|
||||
|
||||
> [!Note]
|
||||
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
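
For cases where the raw decoded string has to be handled by hand, the structure above can be split with a regular expression. This is a minimal sketch based only on the tag layout shown in this section; `split_thinking` is a hypothetical helper, it does not strip other special tokens that may trail the answer, and `processor.parse_response` remains the path the card itself uses.

```python
import re

# Matches one thought block: "<|channel>thought\n" ... "<channel|>".
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def split_thinking(raw: str) -> tuple[str, str]:
    """Return (thought, answer) from a decoded model response.

    If no thought block is present, the thought is an empty string and
    the whole response is treated as the answer.
    """
    match = THOUGHT_RE.search(raw)
    if match is None:
        return "", raw.strip()
    thought = match.group(1).strip()
    answer = raw[match.end():].strip()
    return thought, answer


thought, answer = split_thinking("<|channel>thought\nAdd 2 and 2.<channel|>4")
```

The empty thought block that larger models emit with thinking disabled is handled by the same pattern and yields an empty `thought`.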
|
||||
|
||||
### 3. Multi-Turn Conversations
|
||||
|
||||
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
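
A sketch of how a chat loop might enforce this rule before appending the next user turn. It assumes assistant turns are stored as plain strings that may still contain the `<|channel>thought\n ... <channel|>` block described above; `strip_thoughts` is a hypothetical helper, not a library function.

```python
import re

# One thought block, as laid out in the Thinking Mode section.
THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(history: list[dict]) -> list[dict]:
    """Return a copy of the history with thought blocks removed
    from assistant turns; other turns are passed through unchanged."""
    cleaned = []
    for msg in history:
        if msg["role"] == "assistant" and isinstance(msg["content"], str):
            content = THOUGHT_BLOCK.sub("", msg["content"]).strip()
            msg = {**msg, "content": content}
        cleaned.append(msg)
    return cleaned


history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<|channel>thought\nSimple sum.<channel|>4"},
]
next_turn = strip_thoughts(history) + [{"role": "user", "content": "And times 3?"}]
```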
|
||||
|
||||
### 4. Modality Order
|
||||
|
||||
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
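
For single-turn prompts built programmatically, the ordering can be enforced with a stable partition of the content parts. This is a sketch under the assumption that parts follow the `{"type": ...}` layout used in the snippets above; `media_first` is a hypothetical helper, and it is not appropriate for prompts that deliberately interleave text and images.

```python
MEDIA_TYPES = {"image", "audio", "video"}


def media_first(content: list[dict]) -> list[dict]:
    """Move media parts ahead of text parts, keeping relative order."""
    media = [part for part in content if part.get("type") in MEDIA_TYPES]
    other = [part for part in content if part.get("type") not in MEDIA_TYPES]
    return media + other


content = media_first([
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "https://example.com/picture.png"},  # placeholder URL
])
```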
|
||||
|
||||
### 5. Variable Image Resolution
|
||||
|
||||
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
|
||||
|
||||
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
|
||||
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
|
||||
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
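
The guidance above can be captured in a small lookup. The five budget values come from the list; the task-to-budget mapping and the helper names are illustrative assumptions, and this section does not say which processor argument carries the budget, so the sketch stops at choosing a value.

```python
import bisect

# Supported visual token budgets, from the list above.
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)

# ASSUMPTION: illustrative defaults that follow the low/high guidance.
TASK_BUDGET = {
    "classification": 70,
    "captioning": 140,
    "video": 70,
    "ocr": 1120,
    "document_parsing": 1120,
}


def snap_budget(requested: int) -> int:
    """Round a requested budget up to the nearest supported value,
    capped at the largest supported budget."""
    idx = bisect.bisect_left(SUPPORTED_BUDGETS, requested)
    return SUPPORTED_BUDGETS[min(idx, len(SUPPORTED_BUDGETS) - 1)]


def video_token_cost(num_frames: int, budget: int) -> int:
    """Visual tokens consumed when every frame uses the same budget."""
    return num_frames * snap_budget(budget)
```

For example, 60 frames at the lowest budget cost 4,200 visual tokens, which is why low budgets suit video.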
|
||||
|
||||
### 6. Audio
|
||||
|
||||
Use the following prompt structures for audio processing:
|
||||
|
||||
* **Automatic Speech Recognition (ASR)**

```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.

Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
|
||||
|
||||
* **Automatic Speech Translation (AST)**

```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
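
The two templates can be filled in programmatically. The sketch below uses hypothetical helper names (`asr_prompt`, `ast_prompt`, `audio_message`) and reuses the message layout from the audio snippet earlier in this card, with the audio part placed before the text.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it "
    "into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', then the "
    "translation in {target}."
)


def asr_prompt(language: str) -> str:
    """Fill the speech recognition template."""
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source: str, target: str) -> str:
    """Fill the speech translation template."""
    return AST_TEMPLATE.format(source=source, target=target)


def audio_message(audio_url: str, prompt: str) -> dict:
    """Build a user turn with the audio part before the text part."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }
```

A list such as `[audio_message(url, ast_prompt("English", "French"))]` could then be passed to `processor.apply_chat_template` as in the audio example.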
|
||||
|
||||
### 7. Audio and Video Length
|
||||
|
||||
All models support image inputs and can process videos as sequences of frames, whereas only the E2B and E4B models also support audio inputs. Audio inputs are limited to a maximum length of 30 seconds. Video inputs are limited to a maximum of 60 seconds, assuming frames are sampled at one frame per second.
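
A minimal sketch of staying inside these limits when preparing inputs. The two constants come from the paragraph above; truncating long videos and splitting long audio into consecutive chunks are our own choices for illustration, not behavior described by the card.

```python
MAX_AUDIO_SECONDS = 30   # stated audio limit
MAX_VIDEO_SECONDS = 60   # stated video limit at one frame per second


def frame_timestamps(duration_s: float) -> list[float]:
    """Timestamps (in seconds) to sample at 1 fps, truncated to the limit."""
    usable = min(duration_s, MAX_VIDEO_SECONDS)
    return [float(i) for i in range(int(usable))]


def audio_chunks(duration_s: float) -> list[tuple[float, float]]:
    """Split an audio clip into consecutive (start, end) spans that each
    fit within the audio limit."""
    chunks = []
    start = 0.0
    while start < duration_s:
        end = min(start + MAX_AUDIO_SECONDS, duration_s)
        chunks.append((start, end))
        start = end
    return chunks
```

Each audio chunk would be sent as its own request; stitching the transcripts back together is left to the caller.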
|
||||
|
||||
## **Model Data**
|
||||
|
||||
Data used for model training and how the data was processed.
|
||||
|
||||
### **Training Dataset**
|
||||
|
||||
Our pre-training dataset is a large-scale, diverse collection of data encompassing a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
|
||||
|
||||
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
|
||||
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
|
||||
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
|
||||
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
|
||||
|
||||
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
|
||||
|
||||
### **Data Preprocessing**
|
||||
|
||||
Here are the key data cleaning and filtering methods applied to the training data:
|
||||
|
||||
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
|
||||
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
|
||||
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
|
||||
|
||||
## **Ethics and Safety**
|
||||
|
||||
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
|
||||
|
||||
### **Evaluation Approach**
|
||||
|
||||
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google’s AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
|
||||
|
||||
* Content related to child sexual abuse material and exploitation
|
||||
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
|
||||
* Sexually explicit content
|
||||
* Hate speech (e.g., dehumanizing members of protected groups)
|
||||
* Harassment (e.g., encouraging violence against people)
|
||||
|
||||
### **Evaluation Results**
|
||||
|
||||
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety, while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the models' capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models' performance.
|
||||
|
||||
## **Usage and Limitations**
|
||||
|
||||
These models have certain limitations that users should be aware of.
|
||||
|
||||
### **Intended Usage**
|
||||
|
||||
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
|
||||
|
||||
* **Content Creation and Communication**
    * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
    * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
    * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
    * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
    * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
    * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
    * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
    * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
|
||||
|
||||
### **Limitations**
|
||||
|
||||
* **Training Data**
|
||||
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
|
||||
* The scope of the training dataset determines the subject areas the model can handle effectively.
|
||||
* **Context and Task Complexity**
|
||||
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
|
||||
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
|
||||
* **Language Ambiguity and Nuance**
|
||||
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
|
||||
* **Factual Accuracy**
|
||||
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
|
||||
* **Common Sense**
|
||||
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
|
||||
|
||||
### **Ethical Considerations and Risks**
|
||||
|
||||
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
|
||||
|
||||
* **Bias and Fairness**
|
||||
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
|
||||
* **Misinformation and Misuse**
|
||||
* VLMs can be misused to generate text that is false, misleading, or harmful.
|
||||
* Guidelines for responsible use of the model are provided; see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
|
||||
* **Transparency and Accountability**
|
||||
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
|
||||
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
|
||||
|
||||
**Risks identified and mitigations**:
|
||||
|
||||
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
|
||||
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
|
||||
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
|
||||
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
|
||||
|
||||
### **Benefits**
|
||||
|
||||
At the time of release, this family of models provides high-performance open vision-language model implementations that, compared to similarly sized models, are designed from the ground up for responsible AI development.
|
||||
@@ -0,0 +1,344 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
}
|
||||
{%- elif value is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- if value['required'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if response is mapping -%}
|
||||
{{- 'response:' + tool_name + '{' -}}
|
||||
{%- for key, value in response | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||
{%- for i in range(loop_messages | length) -%}
|
||||
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||
{%- set ns_turn.last_user_idx = i -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||
{%- if loop.index0 > 0 -%}
|
||||
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||
{%- if not prev_nt.found -%}
|
||||
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||
{%- set prev_nt.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||
{%- if not continue_same_model_turn -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||
{%- if message.get('tool_responses') -%}
|
||||
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endfor -%}
|
||||
{%- elif message.get('tool_calls') -%}
|
||||
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||
{%- if ns_tool_scan.stopped -%}
|
||||
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||
{%- set ns_tool_scan.stopped = true -%}
|
||||
{%- else -%}
|
||||
{%- set follow = loop_messages[k] -%}
|
||||
{#- Resolve tool_call_id to function name -#}
|
||||
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Handle content as string or content-parts array -#}
|
||||
{%- set tool_body = follow.get('content') -%}
|
||||
{%- if tool_body is string -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||
{%- set ns_txt = namespace(s='') -%}
|
||||
{%- for part in tool_body -%}
|
||||
{%- if part.get('type') == 'text' -%}
|
||||
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||
{%- else -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- endif -%}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
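For plain-text conversations, the template above reduces to a simple turn layout. The Python sketch below mirrors that layout as an illustration only. `render_gemma4_text_chat` and its parameters are hypothetical names, not part of any library, and the sketch assumes string-only message content with no tools, media items, `reasoning` fields, or consecutive assistant messages. `strip_thinking` is a line-by-line port of the macro of the same name. For real prompts, the template itself (for example through `apply_chat_template`) remains the reference.

```python
def strip_thinking(text):
    """Port of the template's strip_thinking macro: drop <|channel>...<channel|> spans."""
    result = ""
    for part in text.split("<channel|>"):
        if "<|channel>" in part:
            # Keep only the text that precedes the opening channel marker.
            result += part.split("<|channel>")[0]
        else:
            result += part
    return result.strip()


def render_gemma4_text_chat(messages, bos_token="<bos>", enable_thinking=False,
                            add_generation_prompt=True):
    """Hypothetical helper: render string-only messages in the <|turn>/<turn|> layout."""
    out = [bos_token]
    rest = list(messages)
    has_system = bool(rest) and rest[0]["role"] in ("system", "developer")

    # A system turn is emitted when thinking is enabled or a system message exists.
    if enable_thinking or has_system:
        out.append("<|turn>system\n")
        if enable_thinking:
            out.append("<|think|>\n")
        if has_system:
            out.append(rest[0]["content"].strip())
            rest = rest[1:]
        out.append("<turn|>\n")

    for message in rest:
        # The template renames the "assistant" role to "model".
        role = "model" if message["role"] == "assistant" else message["role"]
        content = message["content"]
        content = strip_thinking(content) if role == "model" else content.strip()
        out.append("<|turn>" + role + "\n" + content + "<turn|>\n")

    if add_generation_prompt:
        out.append("<|turn>model\n")
    return "".join(out)


prompt = render_gemma4_text_chat([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": " Hi "},
    {"role": "assistant", "content": "<|channel>thought\nplan<channel|>Hello!"},
    {"role": "user", "content": "Bye"},
])
print(prompt)
```

Note that earlier thinking text is removed from assistant turns when the history is re-rendered, so only the final answers are fed back to the model.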
|
||||
@@ -0,0 +1,74 @@
|
||||
{
|
||||
"audio_token": "<|audio|>",
|
||||
"backend": "tokenizers",
|
||||
"boa_token": "<|audio>",
|
||||
"boi_token": "<|image>",
|
||||
"bos_token": "<bos>",
|
||||
"eoa_token": "<audio|>",
|
||||
"eoc_token": "<channel|>",
|
||||
"eoi_token": "<image|>",
|
||||
"eos_token": "<eos>",
|
||||
"eot_token": "<turn|>",
|
||||
"escape_token": "<|\"|>",
|
||||
"etc_token": "<tool_call|>",
|
||||
"etd_token": "<tool|>",
|
||||
"etr_token": "<tool_response|>",
|
||||
"extra_special_tokens": [
|
||||
"<|video|>"
|
||||
],
|
||||
"image_token": "<|image|>",
|
||||
"mask_token": "<mask>",
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": "<pad>",
|
||||
"padding_side": "left",
|
||||
"processor_class": "Gemma4Processor",
|
||||
"response_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {
|
||||
"const": "assistant"
|
||||
},
|
||||
"thinking": {
|
||||
"type": "string"
|
||||
},
|
||||
"content": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": "<\\|tool_call>(.*?)<tool_call\\|>",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"const": "function"
|
||||
},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"x-regex": "call\\:(?P<name>\\w+)(?P<arguments>\\{.*\\})",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-parser": "gemma4-tool-call",
|
||||
"additionalProperties": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"x-regex": "(\\<\\|channel\\>thought\\n(?P<thinking>.*?)\\<channel\\|\\>)?(?P<tool_calls>\\<\\|tool_call\\>.*\\<tool_call\\|\\>)?(?P<content>(?:(?!\\<turn\\|\\>)(?!\\<\\|tool_response\\>).)+)?(?:\\<turn\\|\\>|\\<\\|tool_response\\>)?"
|
||||
},
|
||||
"soc_token": "<|channel>",
|
||||
"sot_token": "<|turn>",
|
||||
"stc_token": "<|tool_call>",
|
||||
"std_token": "<|tool>",
|
||||
"str_token": "<|tool_response>",
|
||||
"think_token": "<|think|>",
|
||||
"tokenizer_class": "GemmaTokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
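The `response_schema` above describes how raw model output can be split into thinking, tool calls, and content. The sketch below applies the three regular expressions from that schema with Python's `re` module. It is an illustration under assumptions: the use of `re.DOTALL`, the function name `parse_gemma4_response`, and the sample strings are mine, and the `gemma4-tool-call` argument parser named in the schema is not reproduced, so arguments are left as raw text.

```python
import re

# Patterns copied from response_schema; re.DOTALL is an assumption so that
# "." can span newlines inside thinking text and content.
RESPONSE_RE = re.compile(
    r"(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?"
    r"(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?"
    r"(?P<content>(?:(?!\<turn\|\>)(?!\<\|tool_response\>).)+)?"
    r"(?:\<turn\|\>|\<\|tool_response\>)?",
    re.DOTALL,
)
TOOL_CALL_ITER_RE = re.compile(r"<\|tool_call>(.*?)<tool_call\|>", re.DOTALL)
FUNCTION_RE = re.compile(r"call\:(?P<name>\w+)(?P<arguments>\{.*\})", re.DOTALL)


def parse_gemma4_response(text):
    """Hypothetical helper: split decoded output (special tokens kept) into fields."""
    match = RESPONSE_RE.match(text)  # every group is optional, so this always matches
    parsed = {"role": "assistant"}
    if match.group("thinking") is not None:
        parsed["thinking"] = match.group("thinking")
    if match.group("content") is not None:
        parsed["content"] = match.group("content")
    if match.group("tool_calls") is not None:
        calls = []
        for raw_call in TOOL_CALL_ITER_RE.findall(match.group("tool_calls")):
            fn = FUNCTION_RE.match(raw_call)
            if fn:
                calls.append({
                    "type": "function",
                    # Arguments stay as raw text; the real parser is not reproduced here.
                    "function": {"name": fn.group("name"), "arguments": fn.group("arguments")},
                })
        parsed["tool_calls"] = calls
    return parsed


plain = parse_gemma4_response("<|channel>thought\nplan the reply<channel|>Hello there<turn|>")
call = parse_gemma4_response('<|tool_call>call:get_weather{city:<|"|>Bangkok<|"|>}<tool_call|><|tool_response>')
print(plain)
print(call)
```

The patterns match on the delimiter tokens themselves, so a string decoded with special tokens removed would not be split this way.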
|
||||
@@ -0,0 +1,389 @@
|
||||
# Gemma4_(E2B)-Multimodal.ipynb — extracted cells
|
||||
# Source: https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb
|
||||
|
||||
# ===== CELL 0 (markdown) =====
|
||||
# This notebook contains vibe-test examples for the image, text, and audio capabilities of the Gemma-4 model. To get started, install the latest stable release of transformers.
|
||||
|
||||
# ===== CELL 1 (code) =====
|
||||
!pip install -U transformers
|
||||
|
||||
# ===== CELL 2 (markdown) =====
|
||||
# We can load the model with `AutoModelForMultimodalLM` to make use of all its capabilities.
|
||||
|
||||
# ===== CELL 3 (code) =====
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoModelForMultimodalLM, AutoProcessor
|
||||
#model_list = ["google/gemma-4-26B-A4B-it", "google/gemma-4-E4B-it",
|
||||
# "google/gemma-4-E2B-it", "google/gemma-4-31B-it"]
|
||||
model_id = "google/gemma-4-E2B-it"
|
||||
model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
# ===== CELL 4 (markdown) =====
|
||||
# ## Code completion
|
||||
|
||||
# ===== CELL 5 (markdown) =====
|
||||
# We give Gemma-4 a website screenshot and ask it to reproduce the page's HTML.
|
||||
|
||||
# ===== CELL 6 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
|
||||
},
|
||||
{"type": "text", "text": "Write HTML code for this page."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=4000)
|
||||
|
||||
# ===== CELL 7 (code) =====
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 8 (markdown) =====
|
||||
# ## Video Inference
|
||||
|
||||
# ===== CELL 9 (markdown) =====
|
||||
# We test Gemma-4 on video understanding. To run this example with the larger models, which don't take audio input, set `load_audio_from_video=False`.
|
||||
|
||||
# ===== CELL 10 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
|
||||
{"type": "text", "text": "What is happening in the video? What is the song about?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
load_audio_from_video=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=200)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
|
||||
# ===== CELL 11 (code) =====
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 12 (markdown) =====
|
||||
# ## Multimodal Function Calling
|
||||
|
||||
# ===== CELL 13 (code) =====
|
||||
import re
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Gets the current weather for a specific location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
tools = [WEATHER_TOOL]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
|
||||
{"type": "text", "text": "What is the city in this image? Check the weather there right now."},
|
||||
]},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tools=[WEATHER_TOOL],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
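To make the effect of the `tools` argument concrete, the sketch below reproduces, for this one flat schema, the declaration text that the chat template's `format_function_declaration` macro would place in the system turn. It is a hand-written approximation: `render_flat_tool_declaration` is a hypothetical helper, and it assumes non-empty `properties`, `required`, and `type` fields, with every property carrying only a `description` and a scalar `type`. Enums, arrays, nested objects, and `nullable` are not handled.

```python
Q = '<|"|>'  # string delimiter token ("escape_token" in tokenizer_config.json)

WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Gets the current weather for a specific location.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "The city name"},
            },
            "required": ["city"],
        },
    },
}


def render_flat_tool_declaration(tool):
    """Hypothetical helper: render a flat tool schema in the template's declaration format."""
    fn = tool["function"]
    params = fn["parameters"]
    props = []
    for key in sorted(params["properties"]):  # the template iterates with dictsort
        spec = params["properties"][key]
        fields = []
        if spec.get("description"):
            fields.append("description:" + Q + spec["description"] + Q)
        fields.append("type:" + Q + spec["type"].upper() + Q)  # type is always emitted last
        props.append(key + ":{" + ",".join(fields) + "}")
    required = ",".join(Q + name + Q for name in params["required"])
    return (
        "<|tool>declaration:" + fn["name"]
        + "{description:" + Q + fn["description"] + Q
        + ",parameters:{properties:{" + ",".join(props) + "},"
        + "required:[" + required + "],"
        + "type:" + Q + params["type"].upper() + Q + "}"
        + "}<tool|>"
    )


declaration = render_flat_tool_declaration(WEATHER_TOOL)
print(declaration)
```

The format is JSON-like but not JSON: keys are unquoted, strings are wrapped in the `<|"|>` token, and type names are upper-cased.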
|
||||
|
||||
# ===== CELL 14 (code) =====
|
||||
output = model.generate(**inputs, max_new_tokens=1000)
|
||||
|
||||
# ===== CELL 15 (code) =====
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
|
||||
# ===== CELL 16 (code) =====
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 17 (markdown) =====
|
||||
# ## Any-to-any inference
|
||||
|
||||
# ===== CELL 18 (markdown) =====
|
||||
# We can also run the model with the `any-to-any` pipeline.
|
||||
|
||||
# ===== CELL 19 (code) =====
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline("any-to-any", model="google/gemma-4-e2b-it")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# ===== CELL 20 (code) =====
|
||||
pipe(messages)#, load_audio_from_video=True)
|
||||
|
||||
# ===== CELL 21 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
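The list comprehension in the cell above removes the prompt from each generated row by slicing off as many tokens as the corresponding input row has. The sketch below shows the same slicing on plain Python lists with made-up token ids. `trim_prompt_tokens` is a hypothetical name, and the example assumes, as with `generate` on decoder-only models, that each output row begins with its full input row including any left padding.

```python
def trim_prompt_tokens(input_ids, generated_ids):
    """Drop the prompt prefix from each generated sequence, row by row."""
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


# Two prompts padded on the left (pad id 0) to the same length of 4 tokens.
batch_inputs = [[0, 11, 12, 13], [21, 22, 23, 24]]
# Made-up outputs: each row repeats its prompt, then appends new tokens.
batch_outputs = [[0, 11, 12, 13, 101, 102], [21, 22, 23, 24, 201, 202]]

new_tokens = trim_prompt_tokens(batch_inputs, batch_outputs)
print(new_tokens)
```

Because the tokenizer config sets `padding_side` to `left`, all prompt rows in a batch share one length, so the slice point is the same for every row.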
|
||||
|
||||
# ===== CELL 22 (markdown) =====
|
||||
# ## Object detection and pointing
|
||||
|
||||
# ===== CELL 23 (code) =====
|
||||
import re
|
||||
import torch
|
||||
from transformers.image_utils import load_image
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
import json
|
||||
|
||||
# ===== CELL 24 (code) =====
|
||||
image_url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png"
|
||||
image = load_image(image_url)
|
||||
|
||||
# ===== CELL 25 (code) =====
|
||||
def resize_to_48_multiple(image):
    # Despite the name, this crops (it does not rescale) the image so that
    # both sides are multiples of 48 pixels.
    w, h = image.size
    new_w = (w // 48) * 48
    new_h = (h // 48) * 48
    return image.crop((0, 0, new_w, new_h))
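The size arithmetic in `resize_to_48_multiple` is a floor to the nearest multiple of 48. The sketch below isolates that calculation so the effect on concrete sizes is easy to check. `floor_to_multiple` and `cropped_size` are hypothetical helper names, not part of the notebook.

```python
def floor_to_multiple(value, base=48):
    """Largest multiple of `base` that does not exceed `value`."""
    return (value // base) * base


def cropped_size(width, height, base=48):
    """Size that a (width, height) image would have after the crop."""
    return floor_to_multiple(width, base), floor_to_multiple(height, base)


print(cropped_size(1000, 767))  # (960, 720)
print(cropped_size(48, 47))     # (48, 0)
```

The second call shows an edge case: any side shorter than 48 pixels floors to zero, so very small images would need padding or upscaling before this step.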
|
||||
|
||||
# ===== CELL 26 (code) =====
|
||||
def inputs_for_object_detection(image, what_object):
    messages = [
        {
            "role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"What's the bounding box for the {what_object} in the image?"}
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
    )

    return inputs.to(model.device)
|
||||
|
||||
# ===== CELL 27 (code) =====
|
||||
def extract_json(text: str):
    text = text.strip()

    # Strip a leading/trailing Markdown code fence (``` or ```json) if present
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)

    # Try direct parse first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Fallback: extract first JSON object or array
    match = re.search(r'(\{.*\}|\[.*\])', text, re.DOTALL)
    if match:
        candidate = match.group(1)
        return json.loads(candidate)

    raise ValueError("No valid JSON found")
|
||||
|
||||
# ===== CELL 28 (code) =====
|
||||
def detect_object(image_url, what_object):
    image = load_image(image_url)
    image = resize_to_48_multiple(image)
    inputs = inputs_for_object_detection(image, what_object)
    input_len = inputs["input_ids"].shape[-1]
    generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
    generated = processor.decode(generated_outputs[0, input_len:])
    parsed_json = extract_json(generated)[0]
    return parsed_json
|
||||
|
||||
# ===== CELL 29 (code) =====
|
||||
def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):
    dpi = 72
    width, height = image.size
    fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})

    ax.imshow(image)

    # `box` is [ymin, xmin, ymax, xmax] on a resize_shape grid (1000x1000 by default);
    # rescale it to pixel coordinates of the actual image.
    ymin, xmin, ymax, xmax = box
    re_h, re_w = resize_shape if resize_shape is not None else (height, width)
    xmin = (xmin / re_w) * width
    ymin = (ymin/ re_h) * height
    xmax = (xmax / re_w) * width
    ymax = (ymax/ re_h) * height

    w = xmax - xmin
    h = ymax - ymin

    rect = patches.Rectangle(
        (xmin, ymin),
        w,
        h,
        linewidth=10,
        edgecolor="green",
        facecolor="none"
    )
    ax.add_patch(rect)

    if label is not None:
        ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor="yellow", alpha=0.5))

    plt.axis("off")
    plt.savefig(f"boxes_{i}.png")
    plt.close(fig)
    display(fig)
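The coordinate conversion inside `draw_pascal_voc_boxes` can be checked without matplotlib. The sketch below pulls out only the rescaling step, under the same assumption the function makes: the box arrives as `[ymin, xmin, ymax, xmax]` on a 1000 by 1000 grid. `box_to_pixels` is a hypothetical helper name and the sample box is made up.

```python
def box_to_pixels(box, width, height, grid=1000):
    """Convert [ymin, xmin, ymax, xmax] on a grid x grid scale to pixel (xmin, ymin, xmax, ymax)."""
    ymin, xmin, ymax, xmax = box
    return (
        xmin / grid * width,
        ymin / grid * height,
        xmax / grid * width,
        ymax / grid * height,
    )


# Made-up detection on a 640x480 image.
x0, y0, x1, y1 = box_to_pixels([250, 125, 750, 500], width=640, height=480)
print((x0, y0, x1, y1))
print("rectangle width/height:", x1 - x0, y1 - y0)
```

Note the order swap: the input lists y before x, while the returned tuple uses the x-first order that `patches.Rectangle` expects for its anchor point.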
|
||||
|
||||
# ===== CELL 30 (code) =====
|
||||
def display_detected_object(image_url, what_object):
    image = load_image(image_url)
    image = resize_to_48_multiple(image)
    detection = detect_object(image_url, what_object)
    box = detection["box_2d"]
    label = detection.get("label", f"{what_object}")
    draw_pascal_voc_boxes("1000", image, box, label)
|
||||
|
||||
# ===== CELL 31 (code) =====
|
||||
display_detected_object("https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png", "bike")
|
||||
|
||||
# ===== CELL 32 (markdown) =====
|
||||
# ## Captioning
|
||||
|
||||
# ===== CELL 33 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
|
||||
{"type": "text", "text": "Write a single detailed caption for this image."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=512)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 34 (markdown) =====
# ## Audio Understanding

# ===== CELL 35 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
            {"type": "text", "text": "Can you describe this audio in detail?"},
        ],
    },
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

output = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)

print(processor.decode(output[0], skip_special_tokens=True))

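The audio cell decodes `output[0]` in full, while the captioning cell first slices off the first `input_len` tokens. For decoder-only models in transformers, `generate` returns the prompt tokens followed by the new ones, so decoding without slicing also prints the prompt. A minimal sketch of the trimming step on plain lists, with made-up token IDs and a hypothetical helper name `trim_prompt`:

```python
def trim_prompt(prompt_ids, output_ids):
    """Drop the echoed prompt from each generated sequence in a batch."""
    return [out[len(inp):] for inp, out in zip(prompt_ids, output_ids)]


# Made-up token IDs: each output row starts with its own prompt.
prompt_ids = [[1, 2, 3], [4, 5, 6]]
output_ids = [[1, 2, 3, 10, 11], [4, 5, 6, 12, 13]]
new_tokens = trim_prompt(prompt_ids, output_ids)
```

With tensors the same idea is the `output[0][input_len:]` slice used in the captioning cell.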
@@ -0,0 +1,595 @@
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "This notebook has vibe-test examples for the image, text, and audio capabilities of the Gemma-4 model. To get started, let's install the latest stable release of transformers."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "!pip install -U transformers"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "We can load the model with `AutoModelForMultimodalLM` to make use of all its capabilities."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "from transformers import AutoModelForMultimodalLM, AutoProcessor\n",
    "#model_list = [\"google/gemma-4-26B-A4B-it\", \"google/gemma-4-E4B-it\",\n",
    "# \"google/gemma-4-E2B-it\", \"google/gemma-4-31B-it\"]\n",
    "model_id = \"google/gemma-4-E2B-it\"\n",
    "model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map=\"auto\")\n",
    "processor = AutoProcessor.from_pretrained(model_id)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Code completion"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We give Gemma-4 a screenshot of a website and ask it to reproduce the page's HTML."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\n",
    "                \"type\": \"image\",\n",
    "                \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png\",\n",
    "            },\n",
    "            {\"type\": \"text\", \"text\": \"Write HTML code for this page.\"},\n",
    "        ],\n",
    "    }\n",
    "]\n",
    "\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    return_dict=True,\n",
    "    return_tensors=\"pt\",\n",
    "    add_generation_prompt=True,\n",
    "    enable_thinking=True,\n",
    ").to(model.device)\n",
    "\n",
    "output = model.generate(**inputs, max_new_tokens=4000)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "input_len = inputs.input_ids.shape[-1]\n",
    "generated_text_ids = output[0][input_len:]\n",
    "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
    "result = processor.parse_response(generated_text)\n",
    "\n",
    "print(result[\"content\"])"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Video Inference"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We test Gemma-4 on video understanding. To run this example with the larger models, which don't take audio input, disable `load_audio_from_video`."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"video\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4\"},\n",
    "            {\"type\": \"text\", \"text\": \"What is happening in the video? What is the song about?\"},\n",
    "        ],\n",
    "    },\n",
    "]\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    return_dict=True,\n",
    "    return_tensors=\"pt\",\n",
    "    add_generation_prompt=True,\n",
    "    load_audio_from_video=True,\n",
    ").to(model.device)\n",
    "output = model.generate(**inputs, max_new_tokens=200)\n",
    "input_len = inputs.input_ids.shape[-1]\n",
    "generated_text_ids = output[0][input_len:]\n",
    "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
    "result = processor.parse_response(generated_text)\n"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "print(result[\"content\"])"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Multimodal Function Calling"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import re\n",
    "\n",
    "WEATHER_TOOL = {\n",
    "    \"type\": \"function\",\n",
    "    \"function\": {\n",
    "        \"name\": \"get_weather\",\n",
    "        \"description\": \"Gets the current weather for a specific location.\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"city\": {\"type\": \"string\", \"description\": \"The city name\"},\n",
    "            },\n",
    "            \"required\": [\"city\"],\n",
    "        },\n",
    "    },\n",
    "}\n",
    "tools = [WEATHER_TOOL]\n",
    "\n",
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": [\n",
    "        {\"type\": \"image\", \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg\"},\n",
    "        {\"type\": \"text\", \"text\": \"What is the city in this image? Check the weather there right now.\"},\n",
    "    ]},\n",
    "]\n",
    "\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages,\n",
    "    tools=[WEATHER_TOOL],\n",
    "    tokenize=True,\n",
    "    return_dict=True,\n",
    "    return_tensors=\"pt\",\n",
    "    add_generation_prompt=True,\n",
    "    enable_thinking=True,\n",
    ").to(model.device)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "output = model.generate(**inputs, max_new_tokens=1000)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "input_len = inputs.input_ids.shape[-1]\n",
    "generated_text_ids = output[0][input_len:]\n",
    "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
    "result = processor.parse_response(generated_text)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "print(result[\"content\"])"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Any-to-any inference"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We can also run the model with the `any-to-any` pipeline."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"any-to-any\", model=\"google/gemma-4-e2b-it\")\n",
    "\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\n",
    "                \"type\": \"video\",\n",
    "                \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
    "            },\n",
    "            {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
    "        ],\n",
    "    }\n",
    "]\n"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "pipe(messages)#, load_audio_from_video=True)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\n",
    "                \"type\": \"video\",\n",
    "                \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
    "            },\n",
    "            {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
    "        ],\n",
    "    }\n",
    "]\n",
    "\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    add_generation_prompt=True,\n",
    "    return_dict=True,\n",
    "    return_tensors=\"pt\"\n",
    ")\n",
    "inputs = inputs.to(model.device)\n",
    "\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=128)\n",
    "generated_ids_trimmed = [\n",
    "    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
    "]\n",
    "output_text = processor.batch_decode(\n",
    "    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
    ")\n",
    "print(output_text)\n"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Object detection and pointing"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import re\n",
    "import torch\n",
    "from transformers.image_utils import load_image\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "import json"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "image_url = \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\"\n",
    "image = load_image(image_url)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def resize_to_48_multiple(image):\n",
    "    w, h = image.size\n",
    "    new_w = (w // 48) * 48\n",
    "    new_h = (h // 48) * 48\n",
    "    return image.crop((0, 0, new_w, new_h))"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def inputs_for_object_detection(image, what_object):\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"user\", \"content\": [\n",
    "                {\"type\": \"image\", \"image\": image},\n",
    "                {\"type\": \"text\", \"text\": f\"What's the bounding box for the {what_object} in the image?\"}\n",
    "            ]\n",
    "        }\n",
    "    ]\n",
    "\n",
    "    inputs = processor.apply_chat_template(\n",
    "        messages,\n",
    "        tokenize=True,\n",
    "        add_generation_prompt=True,\n",
    "        return_dict=True,\n",
    "        return_tensors=\"pt\",\n",
    "        enable_thinking=False,\n",
    "    )\n",
    "\n",
    "    return inputs.to(model.device)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def extract_json(text: str):\n",
    "    text = text.strip()\n",
    "\n",
    "    text = re.sub(r\"^```(?:json)?\\s*\", \"\", text)\n",
    "    text = re.sub(r\"\\s*```$\", \"\", text)\n",
    "\n",
    "    # Try direct parse first\n",
    "    try:\n",
    "        return json.loads(text)\n",
    "    except json.JSONDecodeError:\n",
    "        pass\n",
    "\n",
    "    # Fallback: extract first JSON object or array\n",
    "    match = re.search(r'(\\{.*\\}|\\[.*\\])', text, re.DOTALL)\n",
    "    if match:\n",
    "        candidate = match.group(1)\n",
    "        return json.loads(candidate)\n",
    "\n",
    "    raise ValueError(\"No valid JSON found\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def detect_object(image_url, what_object):\n",
    "    image = load_image(image_url)\n",
    "    image = resize_to_48_multiple(image)\n",
    "    inputs = inputs_for_object_detection(image, what_object)\n",
    "    input_len = inputs[\"input_ids\"].shape[-1]\n",
    "    generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)\n",
    "    generated = processor.decode(generated_outputs[0, input_len:])\n",
    "    parsed_json = extract_json(generated)[0]\n",
    "    return parsed_json"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):\n",
    "    dpi = 72\n",
    "    width, height = image.size\n",
    "    fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})\n",
    "\n",
    "    ax.imshow(image)\n",
    "\n",
    "    ymin, xmin, ymax, xmax = box\n",
    "    re_h, re_w = resize_shape if resize_shape is not None else (height, width)\n",
    "    xmin = (xmin / re_w) * width\n",
    "    ymin = (ymin / re_h) * height\n",
    "    xmax = (xmax / re_w) * width\n",
    "    ymax = (ymax / re_h) * height\n",
    "\n",
    "    w = xmax - xmin\n",
    "    h = ymax - ymin\n",
    "\n",
    "    rect = patches.Rectangle(\n",
    "        (xmin, ymin),\n",
    "        w,\n",
    "        h,\n",
    "        linewidth=10,\n",
    "        edgecolor=\"green\",\n",
    "        facecolor=\"none\"\n",
    "    )\n",
    "    ax.add_patch(rect)\n",
    "\n",
    "    if label is not None:\n",
    "        ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor=\"yellow\", alpha=0.5))\n",
    "\n",
    "    plt.axis(\"off\")\n",
    "    plt.savefig(f\"boxes_{i}.png\")\n",
    "    plt.close(fig)\n",
    "    display(fig)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def display_detected_object(image_url, what_object):\n",
    "    image = load_image(image_url)\n",
    "    image = resize_to_48_multiple(image)\n",
    "    detection = detect_object(image_url, what_object)\n",
    "    box = detection[\"box_2d\"]\n",
    "    label = detection.get(\"label\", f\"{what_object}\")\n",
    "    draw_pascal_voc_boxes(\"1000\", image, box, label)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "display_detected_object(\"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\", \"bike\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Captioning"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png\"},\n",
    "            {\"type\": \"text\", \"text\": \"Write single detailed caption for this image.\"},\n",
    "        ],\n",
    "    },\n",
    "]\n",
    "\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    return_dict=True,\n",
    "    return_tensors=\"pt\",\n",
    "    add_generation_prompt=True,\n",
    ").to(model.device)\n",
    "\n",
    "output = model.generate(**inputs, max_new_tokens=512)\n",
    "input_len = inputs.input_ids.shape[-1]\n",
    "generated_text_ids = output[0][input_len:]\n",
    "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
    "result = processor.parse_response(generated_text)\n",
    "print(result[\"content\"])"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Audio Understanding"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"audio\", \"url\": \"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3\"},\n",
    "            {\"type\": \"text\", \"text\": \"Can you describe this audio in detail?\"},\n",
    "        ],\n",
    "    },\n",
    "]\n",
    "\n",
    "inputs = processor.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    return_dict=True,\n",
    "    return_tensors=\"pt\",\n",
    "    add_generation_prompt=True,\n",
    ").to(model.device)\n",
    "\n",
    "output = model.generate(\n",
    "    **inputs,\n",
    "    max_new_tokens=1000,\n",
    "    do_sample=False,\n",
    ")\n",
    "\n",
    "print(processor.decode(output[0], skip_special_tokens=True))\n"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }
 ]
}
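Despite its name, the notebook's `resize_to_48_multiple` does not rescale the image: it crops from the top-left corner so that both sides become multiples of 48. A small sketch of just the arithmetic, without PIL; the helper name `crop_box_to_multiple` and the sample sizes are assumptions for illustration.

```python
def crop_box_to_multiple(width, height, multiple=48):
    """Return the (left, upper, right, lower) box that trims each side down to a multiple."""
    new_w = (width // multiple) * multiple
    new_h = (height // multiple) * multiple
    return (0, 0, new_w, new_h)


# A 1000x700 image loses 40 columns on the right and 28 rows at the bottom.
box = crop_box_to_multiple(1000, 700)

# A side shorter than the multiple collapses to zero, which would give an empty crop.
tiny = crop_box_to_multiple(47, 100)
```

Because the crop discards the right and bottom edges, an object near those edges can be cut off before detection; rescaling instead of cropping would avoid that at the cost of slightly changing the aspect ratio.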
@@ -0,0 +1,302 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl",
#     "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env",
# ]
# ///


"""
GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models).

This script uses `environment_factory` with multimodal tool responses: each tool action
returns a camera image from the vehicle alongside the text scene description, allowing the
VLM to see the driving scene visually after each action.

The CARLA environment simulates an emergency driving scenario where pedestrians are ahead
and the model must learn to observe the scene and take the correct action (e.g., swerve
to an empty lane) to minimize casualties.

Setup:
```sh
pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env"
```

Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection):
```sh
python examples/scripts/openenv/carla_vlm.py \
    --env-urls https://server1.hf.space https://server2.hf.space
```
"""

import argparse
import base64
from io import BytesIO

from carla_env import CarlaAction, CarlaEnv
from datasets import Dataset
from PIL import Image

from trl import GRPOConfig, GRPOTrainer


def parse_args():
    parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.")
    parser.add_argument("--model", type=str, default="google/gemma-4-E2B-it")
    parser.add_argument(
        "--env-urls",
        type=str,
        nargs="+",
        required=True,
        help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).",
    )
    parser.add_argument("--dataset-size", type=int, default=1000)
    parser.add_argument("--max-completion-length", type=int, default=3072)
    parser.add_argument("--per-device-train-batch-size", type=int, default=None, help="Defaults to len(env-urls).")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.")
    parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.")
    parser.add_argument("--use-lora", action="store_true", help="Use LoRA for memory-efficient training.")
    parser.add_argument("--lora-r", type=int, default=128, help="LoRA rank.")
    parser.add_argument("--lora-alpha", type=int, default=256, help="LoRA alpha.")
    parser.add_argument(
        "--lora-target-modules",
        type=str,
        default="llm-only",
        help="LoRA target modules. Use 'llm-only' to skip vision encoder, 'all-linear' for all.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=5e-6, help="Learning rate. Default 5e-6 (good for LoRA r=128)."
    )
    parser.add_argument("--hub-model-id", type=str, default=None)
    parser.add_argument("--hub-private-repo", action="store_true", help="Make the Hub repo private.")
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.")
    return parser.parse_args()


SIM_TICKS = 10


def reward_func(completions, environments, **kwargs):
    rewards = []
    for i, (comp, env) in enumerate(zip(completions, environments, strict=False)):
        # Advance to episode end to capture rubric_reward
        try:
            final = env._advance_until_done()
            if final and final.observation.rubric_reward:
                env.reward = final.observation.rubric_reward
        except Exception as e:
            print(f"[WARN] _advance_until_done failed for gen={i}: {e}")
        r = env.reward
        tools = [
            msg["tool_calls"][0]["function"]["name"] for msg in comp if isinstance(msg, dict) and msg.get("tool_calls")
        ]
        rewards.append(r)
        print(f"[DEBUG reward] gen={i} tools={tools} env_reward={env.reward} total={r}")
    return rewards


def main():
    args = parse_args()
    env_url_iter = iter(args.env_urls)
    image_size = args.image_size

    prompt = """\
You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \
decide what to do immediately.

You will see a camera image from the vehicle after each action. Use the visual information
along with the scene description to decide your next action.

You have the following tools available:
- `observe`: Advance time and get a new observation of the scene with a camera image.
- `emergency_stop`: Apply maximum braking to stop the vehicle.
- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right".

Make one tool call at a time, wait for the result, then decide your next action.
Observe the scene first, then decide the best course of action to minimize harm.
Consider all available actions - sometimes avoiding the obstacle by changing lanes \
is safer than stopping in its path."""

    dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]})

    class CarlaVLMEnv:
        def __init__(self):
            self.url = next(env_url_iter)
            self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
            self.reward = 0.0

        @staticmethod
        def _describe(obs) -> str:
            parts = []
            parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.")
            if obs.nearby_actors:
                for actor in obs.nearby_actors:
                    parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m")
            else:
                parts.append("No nearby actors detected.")
            if obs.collision_detected:
                parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!")
            return "\n".join(parts)

        @staticmethod
        def _decode_image(camera_image_b64, target_size):
            """Decode base64 JPEG image and optionally resize."""
            img_bytes = base64.b64decode(camera_image_b64)
            img = Image.open(BytesIO(img_bytes))
            if target_size > 0:
                img.thumbnail((target_size, target_size), Image.LANCZOS)
            return img

        def _format_multimodal(self, obs) -> list:
            """Format observation as multimodal content blocks (camera image + text)."""
            content = []
            if obs.camera_image is not None:
                img = self._decode_image(obs.camera_image, image_size)
                content.append({"type": "image", "image": img})
            content.append({"type": "text", "text": self._describe(obs)})
            return content

        def _advance(self, ticks: int = SIM_TICKS):
            result = None
            for _ in range(ticks):
                result = self.client.step(CarlaAction(action_type="observe"))
                if result.done:
                    break
            return result

        def _advance_until_done(self, max_ticks: int = 50):
            """Advance the simulation until the episode ends."""
            result = None
            for _ in range(max_ticks):
                result = self.client.step(CarlaAction(action_type="observe"))
                if result.done:
                    break
            return result

        def _advance_and_capture(self, ticks: int = SIM_TICKS):
            """Advance the simulation, then capture an image of the current state."""
            result = self._advance(ticks)
            capture_result = self.client.step(CarlaAction(action_type="capture_image"))
            result.observation.camera_image = capture_result.observation.camera_image
            return result

        def reset(self, **kwargs) -> str | None:
            for attempt in range(3):
                try:
                    result = self.client.reset(scenario_name="trolley_micro_escape_exists")
                    self.reward = 0.0
                    return self._describe(result.observation)
                except Exception as e:
                    if attempt == 2:
                        raise
                    print(f"[WARN] reset failed (attempt {attempt + 1}/3): {e}. Reconnecting...")
                    self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)

        def observe(self) -> list:
            """
            Get the current scene with a camera image and description.

            Returns:
                The camera image and scene description with vehicle state and nearby actors.
            """
            result = self._advance_and_capture()
            self.reward = result.observation.rubric_reward or 0.0
            return self._format_multimodal(result.observation)

        def emergency_stop(self) -> list:
            """
            Apply maximum braking to stop the vehicle.

            Returns:
                The camera image and scene description after braking.
            """
            self.client.step(CarlaAction(action_type="emergency_stop"))
            result = self._advance_and_capture()
            self.reward = result.observation.rubric_reward or 0.0
            print(f"[DEBUG env] emergency_stop: done={result.done}, reward={self.reward}")
            return self._format_multimodal(result.observation)

        def lane_change(self, direction: str) -> list:
            """
            Change lane to avoid obstacles.

            Args:
                direction: Direction to change lane, either "left" or "right".

            Returns:
                The camera image and scene description after changing lane.
            """
            self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction))
            result = self._advance_and_capture()
            self.reward = result.observation.rubric_reward or 0.0
            print(f"[DEBUG env] lane_change({direction}): done={result.done}, reward={self.reward}")
            return self._format_multimodal(result.observation)

    peft_config = None
    if args.use_lora:
        from peft import LoraConfig

        if args.lora_target_modules == "llm-only":
            target_modules = "all-linear"
            exclude_modules = ["vision_tower", "multi_modal_projector"]
        else:
            target_modules = args.lora_target_modules
            exclude_modules = None

        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=target_modules,
            exclude_modules=exclude_modules,
            task_type="CAUSAL_LM",
        )

    trainer = GRPOTrainer(
        model=args.model,
        train_dataset=dataset,
        reward_funcs=reward_func,
        peft_config=peft_config,
        args=GRPOConfig(
            chat_template_kwargs={"enable_thinking": False},
            log_completions=True,
            logging_steps=2,
            num_completions_to_print=1,
            max_completion_length=args.max_completion_length,
            per_device_train_batch_size=args.per_device_train_batch_size or len(args.env_urls),
            steps_per_generation=1,
            num_generations=len(args.env_urls),
            max_tool_calling_iterations=10,
            learning_rate=args.learning_rate,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            max_steps=args.max_steps,
            push_to_hub=args.hub_model_id is not None,
            hub_model_id=args.hub_model_id,
            hub_private_repo=args.hub_private_repo,
            run_name=args.run_name,
            report_to=args.report_to,
            trackio_space_id=args.trackio_space_id,
        ),
        environment_factory=CarlaVLMEnv,
    )
    trainer.train()


if __name__ == "__main__":
    main()
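In the script above, `CarlaVLMEnv.__init__` takes its server address with `next(env_url_iter)`, so constructing more environments than there are `--env-urls` would surface as a bare `StopIteration`. The sketch below shows one possible way to make that failure explicit. `EnvUrlPool` is a hypothetical helper, not part of TRL or the CARLA client, and how many environments the trainer actually constructs is not stated in the script.

```python
class EnvUrlPool:
    """Hands out each environment URL exactly once (hypothetical helper)."""

    def __init__(self, urls):
        self._urls = list(urls)
        self._next = 0

    def take(self):
        # Fail with a descriptive error instead of a bare StopIteration.
        if self._next >= len(self._urls):
            raise RuntimeError(
                f"All {len(self._urls)} environment URLs are already in use; pass more --env-urls."
            )
        url = self._urls[self._next]
        self._next += 1
        return url


pool = EnvUrlPool(["https://server1.hf.space", "https://server2.hf.space"])
first, second = pool.take(), pool.take()

# A third request exceeds the pool and reports why.
try:
    pool.take()
    exhausted_message = ""
except RuntimeError as err:
    exhausted_message = str(err)
```

The script ties `num_generations` to `len(args.env_urls)`, which matches the one-connection-per-Space limit described in its docstring.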
@@ -0,0 +1,184 @@
import os

os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import random
from functools import partial

import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor


def collate_fn(examples, processor):
|
||||
messages = list()
|
||||
for sample in examples:
|
||||
audio = sample["audio"]["array"]
|
||||
label = str(sample["text"])
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": audio},
|
||||
{"type": "text", "text": "Please transcribe this audio."},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
batch = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask the tokens that we do not want to include in the loss computation
|
||||
# -100 is ignored during categorical cross entropy loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == processor.tokenizer.audio_token_id] = -100
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
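The masking step in `collate_fn` above can be illustrated without tensors. The following is a minimal pure-Python sketch, not part of the original script: `mask_labels` and the token IDs are hypothetical stand-ins, and the real code performs the same replacement with boolean indexing on a `torch` tensor.

```python
IGNORE_INDEX = -100  # value the comments above say the cross-entropy loss ignores


def mask_labels(input_ids, ignored_token_ids, ignore_index=IGNORE_INDEX):
    """Return a copy of input_ids with every ignored token replaced by ignore_index."""
    ignored = set(ignored_token_ids)
    return [ignore_index if tok in ignored else tok for tok in input_ids]


# Hypothetical IDs for illustration only: 0 = padding, 9 = audio placeholder.
input_ids = [2, 9, 9, 41, 17, 0, 0]
labels = mask_labels(input_ids, ignored_token_ids=[0, 9])
print(labels)
```

Only positions holding real text tokens keep their IDs, so padding and modality placeholder tokens contribute nothing to the loss.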
|
||||
|
||||
|
||||
def freeze_layers(model):
|
||||
for name, param in model.named_parameters():
|
||||
if "attn" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def run_inference(val_dataset, processor, model, fname):
|
||||
# run inference on one random validation sample (called before and during training)
|
||||
val_sample = random.choice(val_dataset)
|
||||
audio = val_sample["audio"]["array"]
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": audio},
|
||||
{"type": "text", "text": "Please transcribe this audio."},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device, dtype=torch.bfloat16)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
with torch.no_grad():
|
||||
generation = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
|
||||
print(f"Audio transcription: {decoded}")
|
||||
print(f"Label: {val_sample['text']}")
|
||||
|
||||
|
||||
def main():
|
||||
model_id = "google/gemma-3n-E2B-it"
|
||||
processor = Gemma3nProcessor.from_pretrained(model_id)
|
||||
|
||||
# Load and split the dataset.
|
||||
ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
|
||||
split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
|
||||
train_dataset = split_ds["train"].select(range(10000))
|
||||
val_dataset = split_ds["test"].select(range(100))
|
||||
|
||||
# create data loader
|
||||
partial_collate_fn = partial(collate_fn, processor=processor)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
)
|
||||
|
||||
# load the model and optimizer
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
|
||||
"cuda", dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
run_inference(val_dataset, processor, model, "pred_before.png")
|
||||
|
||||
model = freeze_layers(model)
|
||||
|
||||
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
|
||||
|
||||
# Start Training
|
||||
accumulation_steps = 8
|
||||
for idx, batch in tqdm(enumerate(train_dataloader)):
|
||||
outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
|
||||
loss = outputs.loss / accumulation_steps
|
||||
if idx % 100 == 0:
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
count = 0
|
||||
for val_batch in tqdm(val_dataloader, desc="Validation"):
|
||||
val_loss = (
|
||||
val_loss
|
||||
+ model(**val_batch.to(model.device, dtype=torch.bfloat16)).loss
|
||||
)
|
||||
count = count + 1
|
||||
val_loss = val_loss / count
|
||||
print(
|
||||
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
|
||||
)
|
||||
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
|
||||
|
||||
loss.backward()
|
||||
if (idx + 1) % accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
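The training loop above divides each loss by `accumulation_steps` so that gradients average over a window of batches. As a hedged illustration (the helper names below are hypothetical and not part of the script), this sketch compares two ways of deciding when the optimizer steps: a check of the form `(idx + 1) % k == 0` steps only after a full window, while `idx % k == 0` also steps right after the first batch.

```python
def full_window_step_indices(num_batches, accumulation_steps):
    """Batch indices after which a step runs when checking (idx + 1) % k == 0."""
    return [
        idx
        for idx in range(num_batches)
        if (idx + 1) % accumulation_steps == 0
    ]


def first_batch_step_indices(num_batches, accumulation_steps):
    """Batch indices after which a step runs when checking idx % k == 0."""
    return [idx for idx in range(num_batches) if idx % accumulation_steps == 0]


print(full_window_step_indices(16, 8))   # steps after batches 7 and 15
print(first_batch_step_indices(16, 8))   # steps after batches 0 and 8
```

With the second form, the first update uses the gradient of a single batch that has already been scaled down by `accumulation_steps`.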
|
||||
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Train Gemma-3n on various vision-language datasets including intersection-dataset.
|
||||
|
||||
For Gemma-3n with intersection dataset:
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
sft_vlm_gemma3n.py \
|
||||
--dataset_name ariG23498/intersection-dataset \
|
||||
--model_name_or_path google/gemma-3n-E2B-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3n-E2B-it-trl-sft-intersection \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma 3 (google/gemma-3-4b-it) on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
sft_vlm_gemma3n.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma 3 (google/gemma-3-4b-it) on the FanqingM/MMIU-Benchmark dataset (multi-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
sft_vlm_gemma3n.py \
|
||||
--dataset_name FanqingM/MMIU-Benchmark \
|
||||
--dataset_train_split test \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
from transformers import (AutoModelForImageTextToText, AutoProcessor,
|
||||
Gemma3nForConditionalGeneration)
|
||||
from trl import (ModelConfig, ScriptArguments, SFTConfig, SFTTrainer,
|
||||
TrlParser, get_kbit_device_map, get_quantization_config)
|
||||
|
||||
|
||||
def my_get_peft_config(model_args: ModelConfig):
|
||||
"""A version of get_peft_config that handles comma-separated target modules"""
|
||||
if model_args.use_peft is False:
|
||||
return None
|
||||
|
||||
# Import here to avoid issues if PEFT is not available
|
||||
try:
|
||||
from peft import LoraConfig
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"You need to have PEFT library installed in your environment, make sure to install `peft`. "
|
||||
"Make sure to run `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Fix the target_modules to be a list if it's a comma-separated string
|
||||
target_modules = model_args.lora_target_modules
|
||||
if isinstance(target_modules, str) and target_modules != "all-linear":
|
||||
# Convert comma-separated string to list
|
||||
target_modules = [module.strip() for module in target_modules.split(",")]
|
||||
|
||||
peft_config = LoraConfig(
|
||||
task_type=model_args.lora_task_type,
|
||||
r=model_args.lora_r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
bias="none",
|
||||
use_rslora=model_args.use_rslora,
|
||||
use_dora=model_args.use_dora,
|
||||
modules_to_save=model_args.lora_modules_to_save,
|
||||
)
|
||||
|
||||
return peft_config
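The target-module handling in `my_get_peft_config` can be exercised on its own. This is a small sketch of just that normalization step, with a hypothetical helper name (`parse_target_modules`); it does not import PEFT and only mirrors the string handling shown above.

```python
def parse_target_modules(value):
    """Turn a comma-separated string into a list; leave other values untouched."""
    if isinstance(value, str) and value != "all-linear":
        return [name.strip() for name in value.split(",")]
    return value


print(parse_target_modules("q_proj, v_proj"))  # comma-separated string becomes a list
print(parse_target_modules("all-linear"))      # special value is passed through as-is
print(parse_target_modules(None))              # no target modules configured
```

Keeping `"all-linear"` as a plain string matters because it is passed to `LoraConfig` as a special value rather than as a module name.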
|
||||
|
||||
|
||||
# For intersection dataset processing
|
||||
def format_intersection_data(samples: dict) -> dict[str, list]:
|
||||
"""Format intersection dataset to match expected message format"""
|
||||
formatted_samples = {"messages": []}
|
||||
for idx in range(len(samples["image"])):
|
||||
image = samples["image"][idx].convert("RGB")
|
||||
label = str(samples["label"][idx])
|
||||
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
formatted_samples["messages"].append(message)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
if image is not None:
|
||||
# Handle dictionary with bytes
|
||||
if isinstance(image, dict) and "bytes" in image:
|
||||
pil_image = Image.open(io.BytesIO(image["bytes"]))
|
||||
image_inputs.append(pil_image.convert("RGB"))
|
||||
# Handle PIL Image objects
|
||||
elif hasattr(image, "convert"):
|
||||
image_inputs.append(image.convert("RGB"))
|
||||
return image_inputs
|
||||
|
||||
|
||||
def format_data(samples: dict) -> dict[str, list]:
|
||||
formatted_samples = {"messages": []}
|
||||
for cont in range(len(samples["question"])):
|
||||
images = []
|
||||
for img_path in samples["input_image_path"][cont]:
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||
images.append({"type": "image", "image": image})
|
||||
except Exception as e:
|
||||
print(f"Error processing image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
formatted_samples["messages"].append(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": samples["context"][cont]}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": images
|
||||
+ [{"type": "text", "text": samples["question"][cont]}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": samples["output"][cont]}],
|
||||
},
|
||||
]
|
||||
)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(
|
||||
dataset: DatasetDict, dataset_name: str, dataset_train_split: str
|
||||
) -> DatasetDict:
|
||||
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
||||
zip_files = [f for f in all_files if f.endswith(".zip")]
|
||||
|
||||
for zip_filename in zip_files:
|
||||
zip_path = hf_hub_download(
|
||||
repo_id=dataset_name, filename=zip_filename, repo_type="dataset"
|
||||
)
|
||||
extract_folder = zip_filename.replace(".zip", "")
|
||||
os.makedirs(extract_folder, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_folder)
|
||||
|
||||
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
################
|
||||
# Model, Tokenizer & Processor
|
||||
################
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
|
||||
# Use appropriate model class based on model name
|
||||
if "gemma-3n" in model_args.model_name_or_path.lower():
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
texts = []
|
||||
images_list = []
|
||||
|
||||
for example in examples:
|
||||
# Apply chat template to get text
|
||||
text = processor.apply_chat_template(
|
||||
example["messages"], tokenize=False, add_generation_prompt=False
|
||||
).strip()
|
||||
texts.append(text)
|
||||
|
||||
# Extract images
|
||||
if "images" in example: # single-image case
|
||||
images = [img.convert("RGB") for img in example["images"]]
|
||||
else: # multi-image case or intersection dataset
|
||||
images = process_vision_info(example["messages"])
|
||||
images_list.append(images)
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts, images=images_list, return_tensors="pt", padding=True
|
||||
)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
|
||||
# Mask tokens for Gemma3n model
|
||||
if "gemma-3n" in model_args.model_name_or_path.lower():
|
||||
# Use Gemma3n specific token masking
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
if hasattr(processor.tokenizer, "image_token_id"):
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
if hasattr(processor.tokenizer, "boi_token_id"):
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
if hasattr(processor.tokenizer, "eoi_token_id"):
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
else:
|
||||
# Original masking for other models
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.tokenizer.special_tokens_map["boi_token"]
|
||||
)
|
||||
]
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100  # 262144 is the image soft-token ID in the Gemma 3 tokenizer
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
# Handle different dataset formats
|
||||
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
|
||||
dataset = prepare_dataset(
|
||||
dataset, script_args.dataset_name, script_args.dataset_train_split
|
||||
)
|
||||
elif script_args.dataset_name == "ariG23498/intersection-dataset":
|
||||
# Format intersection dataset
|
||||
dataset = dataset.map(
|
||||
format_intersection_data, batched=True, batch_size=4, num_proc=4
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split]
|
||||
if training_args.eval_strategy != "no"
|
||||
else None,
|
||||
processing_class=processor.tokenizer,
|
||||
peft_config=my_get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
if trainer.accelerator.is_main_process:
|
||||
processor.push_to_hub(training_args.hub_model_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
|
||||
|
||||
|
||||
def collate_fn(examples, processor):
|
||||
messages = list()
|
||||
for sample in examples:
|
||||
image = sample["image"].convert("RGB")
|
||||
label = str(sample["label"])
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
batch = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask the tokens that we do not want to include in the loss computation
|
||||
# -100 is ignored during categorical cross entropy loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def freeze_layers(model):
|
||||
for name, param in model.named_parameters():
|
||||
if "attn" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def run_inference(val_dataset, processor, model, fname):
|
||||
# run inference on one random validation sample (called before and during training)
|
||||
val_sample = random.choice(val_dataset)
|
||||
image = val_sample["image"].convert("RGB")
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device, dtype=torch.bfloat16)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
with torch.no_grad():
|
||||
generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
|
||||
plt.imshow(image)
|
||||
plt.axis("off")
|
||||
plt.title(f"Pred: {decoded}")
|
||||
plt.savefig(f"outputs_fine_tune/{fname}")
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
model_id = "google/gemma-3n-E2B-it"
|
||||
processor = Gemma3nProcessor.from_pretrained(model_id)
|
||||
|
||||
# load the dataset
|
||||
dataset_id = "ariG23498/intersection-dataset"
|
||||
train_dataset = load_dataset(dataset_id, split="train")
|
||||
val_dataset = load_dataset(dataset_id, split="validation")
|
||||
|
||||
# create data loader
|
||||
partial_collate_fn = partial(collate_fn, processor=processor)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
)
|
||||
|
||||
# load the model and optimizer
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")
|
||||
|
||||
run_inference(val_dataset, processor, model, "pred_before.png")
|
||||
|
||||
model = freeze_layers(model)
|
||||
|
||||
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
|
||||
|
||||
# Start Training
|
||||
accumulation_steps = 8
|
||||
for idx, batch in tqdm(enumerate(train_dataloader)):
|
||||
outputs = model(**batch.to(model.device))
|
||||
loss = outputs.loss / accumulation_steps
|
||||
if idx % 50 == 0:
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
count = 0
|
||||
for val_batch in val_dataloader:
|
||||
val_loss = val_loss + model(**val_batch.to(model.device)).loss
|
||||
count = count + 1
|
||||
val_loss = val_loss / count
|
||||
print(
|
||||
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
|
||||
)
|
||||
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
|
||||
|
||||
loss.backward()
|
||||
if (idx + 1) % accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,425 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Gemma3n Fine-tuning on All Modalities.ipynb
|
||||
|
||||
Automatically generated by Colab.
|
||||
|
||||
Original file is located at
|
||||
https://colab.research.google.com/drive/1iEZUJuvKJpGU8t50BqfkiCQmGkaR6gd4
|
||||
|
||||
# Fine-tune Gemma3n on FineVideo
|
||||
|
||||
In this notebook, we will see how to fine-tune Gemma3n on videos with audio.
|
||||
Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial to fit the model in 40GB VRAM.
|
||||
"""
|
||||
|
||||
!pip install -U -q timm transformers trl peft datasets
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
|
||||
|
||||
from trl import (
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
)
|
||||
|
||||
"""## Download videos and preprocessing
|
||||
|
||||
FineVideo is a large dataset and we don't need many examples, so we stream it, check each video's duration, and download only the videos shorter than 30 seconds.
|
||||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
import json
|
||||
import os
|
||||
|
||||
dataset = load_dataset("HuggingFaceFV/finevideo", split="train", streaming=True)
|
||||
|
||||
|
||||
os.makedirs("videos", exist_ok=True)
|
||||
os.makedirs("metadata", exist_ok=True)
|
||||
|
||||
for idx, sample in enumerate(dataset):
|
||||
data = sample["json"]
|
||||
duration = data.get("duration_seconds", 0)
|
||||
if duration < 30:
|
||||
video_filename = f"videos/sample_{idx}.mp4"
|
||||
with open(video_filename, 'wb') as video_file:
|
||||
video_file.write(sample['mp4'])
|
||||
|
||||
json_filename = f"metadata/sample_{idx}.json"
|
||||
with open(json_filename, 'w') as json_file:
|
||||
json.dump(sample['json'], json_file)
|
||||
|
||||
print(f"Number of items in content/videos: {len(os.listdir('videos'))}")
|
||||
|
||||
"""Some FineVideo frames are dark, so we sample 6 frames per video and remove any video from which we cannot get enough non-dark frames."""
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
def is_dark(frame, threshold=10):
|
||||
return np.max(frame) < threshold # all pixels are very close to 0
|
||||
|
||||
def downsample_video(video_path):
|
||||
vidcap = cv2.VideoCapture(video_path)
|
||||
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
||||
|
||||
frames = []
|
||||
|
||||
# Generate 8 evenly spaced indices, skip first and last
|
||||
full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1]
|
||||
|
||||
for i in full_indices:
|
||||
found_valid = False
|
||||
for offset in [0, -1, 1, -2, 2]: # Try nearby frames if original is dark
|
||||
candidate_idx = i + offset
|
||||
if 0 <= candidate_idx < total_frames:
|
||||
vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx)
|
||||
success, image = vidcap.read()
|
||||
if success:
|
||||
if not is_dark(image):
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image)
|
||||
timestamp = round(candidate_idx / fps, 2)
|
||||
frames.append((pil_image, timestamp))
|
||||
found_valid = True
|
||||
break
|
||||
if not found_valid:
|
||||
print(f"Warning: Could not find non-dark frame near index {i}")
|
||||
|
||||
vidcap.release()
|
||||
|
||||
# If fewer than 6 non-dark frames were found, top off (up to 8) by scanning from the first frame
|
||||
if len(frames) < 6:
|
||||
print("Trying to top off with additional non-dark frames...")
|
||||
vidcap = cv2.VideoCapture(video_path)  # reopen: the capture was already released above
idx = 0
|
||||
while len(frames) < 8 and idx < total_frames:
|
||||
vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||||
success, image = vidcap.read()
|
||||
if success and not is_dark(image):
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image)
|
||||
timestamp = round(idx / fps, 2)
|
||||
# Avoid adding duplicate timestamps
|
||||
if not any(ts == timestamp for _, ts in frames):
|
||||
frames.append((pil_image, timestamp))
|
||||
idx += 1
|
||||
|
||||
return frames[:8]  # Return at most 8 frames
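The index selection in `downsample_video` can be checked without OpenCV or NumPy. The sketch below is an approximation under stated assumptions: it uses integer floor division in place of `np.linspace(..., dtype=int)`, and both helper names are hypothetical. It shows why asking for 8 evenly spaced points and dropping the first and last leaves 6 sample positions, and which neighbouring frames are tried when one is dark.

```python
def evenly_spaced_indices(total_frames, num_points=8):
    """Pick num_points evenly spaced frame indices, then drop the first and last."""
    last = total_frames - 1
    points = [i * last // (num_points - 1) for i in range(num_points)]
    return points[1:-1]


def candidate_indices(index, total_frames, offsets=(0, -1, 1, -2, 2)):
    """Frames to try, in order, around index; out-of-range candidates are skipped."""
    return [index + off for off in offsets if 0 <= index + off < total_frames]


print(evenly_spaced_indices(100))   # six interior positions for a 100-frame clip
print(candidate_indices(0, 100))    # negative neighbours are dropped at the start
```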
|
||||
|
||||
import os
|
||||
import glob
|
||||
|
||||
def remove_dark_videos(video_dir, metadata_dir, audio_dir):
|
||||
"""
|
||||
Remove videos (and their metadata/audio files) if all frames are dark.
|
||||
"""
|
||||
video_paths = glob.glob(os.path.join(video_dir, "*.mp4"))
|
||||
|
||||
for video_path in video_paths:
|
||||
filename = os.path.basename(video_path)
|
||||
base_name = os.path.splitext(filename)[0]
|
||||
|
||||
frames = downsample_video(video_path)
|
||||
if len(frames) < 6:
|
||||
try:
|
||||
os.remove(video_path)
|
||||
print(f"Deleted: {video_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {video_path}: {e}")
|
||||
|
||||
metadata_path = os.path.join(metadata_dir, f"{base_name}.json")
|
||||
if os.path.exists(metadata_path):
|
||||
os.remove(metadata_path)
|
||||
|
||||
# Remove audio
|
||||
audio_path = os.path.join(audio_dir, f"{base_name}.wav")
|
||||
if os.path.exists(audio_path):
|
||||
os.remove(audio_path)
|
||||
|
||||
remove_dark_videos(
|
||||
video_dir="videos",
|
||||
metadata_dir="metadata",
|
||||
audio_dir="audios"
|
||||
)
|
||||
|
||||
"""Gemma-3n accepts video (image frames) and audio separately, so we strip audio from video."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
video_dir = "videos"
|
||||
audio_dir = "audios"
|
||||
os.makedirs(audio_dir, exist_ok=True)
|
||||
|
||||
for filename in os.listdir(video_dir):
|
||||
if not filename.endswith(".mp4"):
|
||||
continue
|
||||
|
||||
idx = filename.split("_")[1].split(".")[0]
|
||||
video_path = os.path.join(video_dir, filename)
|
||||
audio_path = os.path.join(audio_dir, f"sample_{idx}.wav")
|
||||
|
||||
subprocess.run([
|
||||
"ffmpeg", "-i", video_path,
|
||||
"-q:a", "0", "-map", "a",
|
||||
audio_path,
|
||||
"-y"
|
||||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
|
||||
"""Construct a new dataset with audio, video, and metadata (video categories). FineVideo also includes question-answer pairs, captions, and more, so get creative if you have the GPU VRAM to do so. Here we solve an easier task for educational purposes."""
|
||||
|
||||
from datasets import Dataset
|
||||
import json
|
||||
|
||||
def gen():
|
||||
meta_dir = "metadata"
|
||||
for filename in os.listdir(meta_dir):
|
||||
if not filename.endswith(".json"):
|
||||
continue
|
||||
|
||||
idx = filename.split("_")[1].split(".")[0]
|
||||
if os.path.exists(f"videos/sample_{idx}.mp4"):
|
||||
video_filename = f"sample_{idx}.mp4"
|
||||
audio_filename = f"sample_{idx}.wav"
|
||||
json_path = os.path.join(meta_dir, filename)
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
|
||||
yield {
|
||||
"video": video_filename,
|
||||
"audio": audio_filename,
|
||||
"content_parent_category": metadata["content_parent_category"],
|
||||
"sample_index": int(idx)
|
||||
}
|
||||
else:
|
||||
pass
|
||||
|
||||
dataset = Dataset.from_generator(gen)
|
||||
|
||||
"""We will speed up and downsample the audio clips to save space during training."""
|
||||
|
||||
import torchaudio
|
||||
from torchaudio.transforms import Resample
|
||||
import os
|
||||
import torch
|
||||
|
||||
def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25):
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform.mean(dim=0, keepdim=True)
|
||||
|
||||
if sample_rate != target_sample_rate:
|
||||
resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
||||
waveform = resampler(waveform)
|
||||
sample_rate = target_sample_rate
|
||||
|
||||
if speedup_factor > 1.0:
|
||||
indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long()
|
||||
if indices[-1] >= waveform.shape[1]:
|
||||
indices = indices[:-1]
|
||||
waveform = waveform[:, indices]
|
||||
|
||||
max_length = int(target_sample_rate * max_duration_sec)
|
||||
if waveform.shape[1] > max_length:
|
||||
waveform = waveform[:, :max_length]
|
||||
|
||||
torchaudio.save(audio_path, waveform, sample_rate)
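The speed-up in `preprocess_audio` works by keeping samples at fractional strides rather than by a pitch-preserving time stretch. The following pure-Python sketch uses hypothetical helper names and plain lists instead of tensors to show which sample positions survive for a factor of 1.25, and how the clip is then cut to a maximum duration.

```python
def speed_up(samples, factor):
    """Keep the samples at positions int(0), int(factor), int(2 * factor), ..."""
    picked = []
    k = 0
    while k * factor < len(samples):
        picked.append(samples[int(k * factor)])
        k += 1
    return picked


def truncate(samples, sample_rate, max_duration_sec):
    """Keep at most max_duration_sec seconds of audio."""
    return samples[: int(sample_rate * max_duration_sec)]


samples = list(range(10))
print(speed_up(samples, 1.25))   # every fifth sample is dropped
print(truncate(samples, sample_rate=2, max_duration_sec=2.5))
```

Because samples are dropped while the sample rate stays the same, the clip gets shorter and its pitch rises along with its tempo.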
|
||||
|
||||
for file_name in os.listdir("audios"):
|
||||
if file_name.lower().endswith(".wav"):
|
||||
audio_path = os.path.join("audios", file_name)
|
||||
preprocess_audio(audio_path)
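
The speed-up step above keeps every `speedup_factor`-th sample position, truncated to an integer index. The sketch below reproduces that index selection in plain Python so the effect is easy to inspect. `speedup_indices` is a hypothetical helper, not part of the notebook, and it assumes that truncating `i * speedup_factor` matches what `torch.arange(...).long()` does for these inputs.

```python
import math


def speedup_indices(num_samples: int, speedup_factor: float) -> list[int]:
    """Return the sample indices kept when speeding audio up by `speedup_factor`."""
    if speedup_factor <= 1.0:
        # Nothing to drop: keep every sample.
        return list(range(num_samples))
    count = math.ceil(num_samples / speedup_factor)
    # Truncate the fractional positions toward zero to get integer indices.
    indices = [int(i * speedup_factor) for i in range(count)]
    # Guard against an index running past the end of the signal.
    return [i for i in indices if i < num_samples]


kept = speedup_indices(10, 1.25)
print(kept)  # [0, 1, 2, 3, 5, 6, 7, 8]
```

With a factor of 1.25, one sample in five is dropped, so one second of 16 kHz audio shrinks to 12,800 samples. This is plain sample dropping: it also raises the pitch and applies no anti-aliasing filter, which is probably acceptable for a space-saving demo but worth knowing before reusing it elsewhere.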
|
||||
|
||||
dataset = dataset.train_test_split(test_size=0.10, seed=42)
|
||||
|
||||
"""### Load the model
|
||||
|
||||
Make sure you have your Hugging Face token in your Colab secrets.
|
||||
"""
|
||||
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(
|
||||
"google/gemma-3n-E2B-it", torch_dtype=torch.bfloat16,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/gemma-3n-E2B-it",
|
||||
)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
|
||||
processor.tokenizer.all_special_ids
|
||||
|
||||
"""Write our dataset collator. We will train the model to predict the category of a video, which is an easy task. You can go much further: for instance, FineVideo has a QnA section, so you can train this model for open-ended QnA if you have a lot of VRAM and a lot of patience. Open-ended tasks are harder to work with, and this notebook is educational: it shows how to feed in different modalities.

In the collator we also downsample videos to 6 frames, using the helper written above. For better results you need more frames.
|
||||
"""
|
||||
|
||||
def collate_fn(examples):
    # The batch size is 1 (see per_device_train_batch_size), so only the first example is used.
    video_path = examples[0]["video"]
    audio_path = examples[0]["audio"]
    # Sample index as yielded by gen(); used to name the frame files saved below.
    sample_idx = examples[0]["sample_index"]
    frames = downsample_video(f"videos/{video_path}")

    text = "Based on the video, predict the category of it."
    message = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text}
            ],
        },
    ]
    # this is how video inference should be formatted in Gemma3n
    for frame in frames:
        image, timestamp = frame
        message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
        timestamp = str(timestamp).replace(".", "_")
        image.save(f"image_idx_{sample_idx}_{timestamp}.png")
        message[0]["content"].append({"type": "image", "url": f"image_idx_{sample_idx}_{timestamp}.png"})

    message[0]["content"].append({"type": "audio", "audio": f"audios/{audio_path}"})
    message.append({"role": "assistant", "content": [{"type": "text", "text": examples[0]["content_parent_category"]}]})
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Use the input ids as labels, but exclude special tokens from the loss.
    labels = inputs["input_ids"].clone()
    special_token_ids = processor.tokenizer.all_special_ids

    special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device)
    mask = torch.isin(labels, special_token_ids_tensor)
    labels[mask] = -100

    inputs["labels"] = labels
    if torch.all(inputs["pixel_values"] == 0):
        print("Frames are dark")

    return inputs
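
The label masking at the end of the collator can be illustrated without PyTorch. The sketch below uses made-up token ids (an assumption for illustration only) and a hypothetical `mask_labels` helper; the value `-100` is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_labels(input_ids: list[int], special_ids: set[int]) -> list[int]:
    """Copy the input ids as labels, replacing special tokens with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok in special_ids else tok for tok in input_ids]


# Hypothetical token ids, for illustration only.
input_ids = [2, 105, 9876, 4321, 106, 1]
special_ids = {1, 2, 105, 106}

labels = mask_labels(input_ids, special_ids)
print(labels)  # [-100, -100, 9876, 4321, -100, -100]
```

Only special tokens are masked under this scheme, so the ordinary text tokens of the user prompt still contribute to the loss alongside the assistant's answer. Masking the whole prompt would be a different design choice from the one the collator makes.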
|
||||
|
||||
"""## Training
|
||||
|
||||
We use LoRA fine-tuning again to save memory.
|
||||
"""
|
||||
|
||||
from peft import LoraConfig
|
||||
peft_config = LoraConfig(
|
||||
task_type="CAUSAL_LM",
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
use_rslora=False,
|
||||
use_dora=False,
|
||||
modules_to_save=None
|
||||
)
|
||||
|
||||
model.gradient_checkpointing_disable()
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
training_args = SFTConfig(
|
||||
output_dir="/content/gemma-3n-finevideo",
|
||||
eval_strategy='epoch',
|
||||
per_device_train_batch_size=1,
|
||||
per_device_eval_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_checkpointing=False,
|
||||
learning_rate=1e-05,
|
||||
num_train_epochs=3.0,
|
||||
logging_steps=10,
|
||||
save_steps=100,
|
||||
bf16=True,
|
||||
report_to=["tensorboard"],
|
||||
dataset_kwargs={'skip_prepare_dataset': True},
|
||||
remove_unused_columns=False,
|
||||
max_seq_length=None,
|
||||
push_to_hub=True,
|
||||
dataloader_pin_memory=False,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None,
|
||||
processing_class=processor.tokenizer,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
"""Test the model with a video of snowboarding."""
|
||||
|
||||
!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4
|
||||
|
||||
model = trainer.model # trainer has the adapter
|
||||
|
||||
"""Strip audio and downsample video."""
|
||||
|
||||
audio_path = "/content/test_audio.wav"
|
||||
subprocess.run([
|
||||
"ffmpeg", "-i", "/content/IMG_8137.mp4",
|
||||
"-q:a", "0", "-map", "a",
|
||||
f"{audio_path}",
|
||||
"-y"
|
||||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
|
||||
frames = downsample_video("/content/IMG_8137.mp4")
|
||||
|
||||
# repeat the chat template
|
||||
text = "Based on the video, predict the category of it."
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": text}
|
||||
],
|
||||
},
|
||||
]
|
||||
for frame in frames:
|
||||
image, timestamp = frame
|
||||
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
|
||||
timestamp = str(timestamp).replace(".", "_")
|
||||
image.save(f"test_frame_{timestamp}.png")
|
||||
message[0]["content"].append({"type": "image", "url": f"test_frame_{timestamp}.png"})
|
||||
|
||||
message[0]["content"].append({"type": "audio", "audio": f"{audio_path}"})
|
||||
|
||||
message
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(model.device).to(model.dtype)
|
||||
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
with torch.inference_mode():
|
||||
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
print(decoded)
|
||||
|
||||
"""Thanks a lot for reading! Keep training the model further with more data or unfreeze the layers for better performance 💗"""
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
title: Gemma 4 31B It
|
||||
emoji: 🚀
|
||||
colorFrom: blue
|
||||
colorTo: green
|
||||
sdk: gradio
|
||||
sdk_version: 6.12.0
|
||||
python_version: "3.12.12"
|
||||
app_file: app.py
|
||||
pinned: false
|
||||
---
|
||||
|
||||
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
||||
@@ -0,0 +1,303 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from threading import Thread
|
||||
|
||||
import gradio as gr
|
||||
import spaces
|
||||
import torch
|
||||
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
|
||||
from transformers.generation.streamers import TextIteratorStreamer
|
||||
|
||||
MODEL_ID = "google/gemma-4-31b-it"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16)
|
||||
|
||||
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
|
||||
VIDEO_FILE_TYPES = (".mp4", ".mov", ".avi", ".webm")
|
||||
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
|
||||
|
||||
THINKING_START = "<|channel>"
|
||||
THINKING_END = "<channel|>"
|
||||
|
||||
# Special tokens to strip from decoded output (keeping thinking delimiters
|
||||
# so that Gradio's reasoning_tags can find them on the frontend).
|
||||
_KEEP_TOKENS = {THINKING_START, THINKING_END}
|
||||
_STRIP_TOKENS = sorted(
|
||||
(t for t in processor.tokenizer.all_special_tokens if t not in _KEEP_TOKENS),
|
||||
key=len,
|
||||
reverse=True, # longest first to avoid partial matches
|
||||
)
|
||||
|
||||
|
||||
def _strip_special_tokens(text: str) -> str:
|
||||
for tok in _STRIP_TOKENS:
|
||||
text = text.replace(tok, "")
|
||||
return text
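
The `reverse=True` ordering above matters when one special token contains another as a substring. Here is a small sketch of the failure it avoids, using two invented tokens (they are hypothetical and not taken from the Gemma tokenizer):

```python
def strip_tokens(text: str, tokens: list[str]) -> str:
    """Remove each token from the text, in the order given."""
    for tok in tokens:
        text = text.replace(tok, "")
    return text


# Hypothetical special tokens; the second one starts with the first.
tokens = ["<|img|>", "<|img|>_end"]
raw = "a<|img|>_endb"

shortest_first = strip_tokens(raw, sorted(tokens, key=len))
longest_first = strip_tokens(raw, sorted(tokens, key=len, reverse=True))

print(shortest_first)  # a_endb  (the shorter token matched first and left debris)
print(longest_first)   # ab
```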
|
||||
|
||||
|
||||
def _classify_file(path: str) -> str | None:
|
||||
"""Return media type string for a file path, or None if unsupported."""
|
||||
lower = path.lower()
|
||||
if lower.endswith(IMAGE_FILE_TYPES):
|
||||
return "image"
|
||||
if lower.endswith(VIDEO_FILE_TYPES):
|
||||
return "video"
|
||||
return None
|
||||
|
||||
|
||||
def process_new_user_message(message: dict) -> list[dict]:
|
||||
"""Build content list from the new user message with URL-based media references."""
|
||||
content: list[dict] = []
|
||||
for path in message.get("files", []):
|
||||
kind = _classify_file(path)
|
||||
if kind:
|
||||
content.append({"type": kind, "url": path})
|
||||
content.append({"type": "text", "text": message.get("text", "")})
|
||||
return content
|
||||
|
||||
|
||||
def process_history(history: list[dict]) -> list[dict]:
|
||||
"""Walk Gradio 6 history and build message list with URL-based media references."""
|
||||
messages: list[dict] = []
|
||||
|
||||
for item in history:
|
||||
if item["role"] == "assistant":
|
||||
text_parts = [p["text"] for p in item["content"] if p.get("type") == "text"]
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": " ".join(text_parts)}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
user_content: list[dict] = []
|
||||
for part in item["content"]:
|
||||
if part.get("type") == "text":
|
||||
user_content.append({"type": "text", "text": part["text"]})
|
||||
elif part.get("type") == "file":
|
||||
filepath = part["file"]["path"]
|
||||
kind = _classify_file(filepath)
|
||||
if kind:
|
||||
user_content.append({"type": kind, "url": filepath})
|
||||
if user_content:
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@spaces.GPU(duration=180)
|
||||
@torch.inference_mode()
|
||||
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
|
||||
inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
processor,
|
||||
timeout=30.0,
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=not thinking,
|
||||
)
|
||||
generate_kwargs = {
|
||||
**inputs,
|
||||
"streamer": streamer,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"disable_compile": True,
|
||||
}
|
||||
|
||||
exception_holder: list[Exception] = []
|
||||
|
||||
def _generate() -> None:
|
||||
try:
|
||||
model.generate(**generate_kwargs)
|
||||
except Exception as e: # noqa: BLE001
|
||||
exception_holder.append(e)
|
||||
|
||||
thread = Thread(target=_generate)
|
||||
thread.start()
|
||||
|
||||
chunks: list[str] = []
|
||||
for text in streamer:
|
||||
chunks.append(text)
|
||||
accumulated = "".join(chunks)
|
||||
if thinking:
|
||||
yield _strip_special_tokens(accumulated)
|
||||
else:
|
||||
yield accumulated
|
||||
|
||||
thread.join()
|
||||
if exception_holder:
|
||||
msg = f"Generation failed: {exception_holder[0]}"
|
||||
raise gr.Error(msg)
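
The function above runs generation in a background thread, yields the accumulated text after each streamed chunk, and surfaces worker exceptions once the stream ends. The sketch below shows the same pattern using only the standard library. A `queue.Queue` stands in for `TextIteratorStreamer`, and `stream_from_worker` and `fake_generate` are hypothetical names. One deliberate difference: the sketch pushes a sentinel in `finally`, so the consumer is released even when the worker fails. No claim is made here about how the real streamer behaves in that case.

```python
import queue
import threading
from collections.abc import Iterator

_DONE = object()  # sentinel marking the end of the stream


def stream_from_worker(produce) -> Iterator[str]:
    """Run `produce(emit)` in a thread and yield the accumulated text after each chunk."""
    chunk_queue: queue.Queue = queue.Queue()
    errors: list[Exception] = []

    def worker() -> None:
        try:
            produce(chunk_queue.put)
        except Exception as exc:  # collect the error for the consumer thread
            errors.append(exc)
        finally:
            chunk_queue.put(_DONE)  # always unblock the consumer

    thread = threading.Thread(target=worker)
    thread.start()

    chunks: list[str] = []
    while (item := chunk_queue.get()) is not _DONE:
        chunks.append(item)
        yield "".join(chunks)

    thread.join()
    if errors:
        raise RuntimeError(f"Generation failed: {errors[0]}")


def fake_generate(emit) -> None:
    """Stand-in for model.generate: emits three text chunks."""
    for piece in ["Hel", "lo", "!"]:
        emit(piece)


outputs = list(stream_from_worker(fake_generate))
print(outputs)  # ['Hel', 'Hello', 'Hello!']
```

Yielding the full accumulated string each time, rather than the latest chunk, matches what a chat UI needs when it replaces the displayed message on every update.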
|
||||
|
||||
|
||||
def validate_input(message: dict) -> dict:
|
||||
has_text = bool(message.get("text", "").strip())
|
||||
has_files = bool(message.get("files"))
|
||||
if not (has_text or has_files):
|
||||
return gr.validate(False, "Please enter a message or upload a file.")
|
||||
|
||||
files = message.get("files", [])
|
||||
kinds = [_classify_file(f) for f in files]
|
||||
kinds = [k for k in kinds if k is not None]
|
||||
unique_kinds = set(kinds)
|
||||
|
||||
if len(unique_kinds) > 1:
|
||||
return gr.validate(False, "Please upload only one type of media (images or video) at a time.")
|
||||
if kinds.count("video") > 1:
|
||||
return gr.validate(False, "Only one video file can be uploaded at a time.")
|
||||
|
||||
return gr.validate(True, "")
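
The validation rules above can be exercised without Gradio. The sketch below restates them as a pure function that returns an `(ok, message)` pair. `check_upload` and `classify` are hypothetical names, and the extension lists are copied from the constants earlier in this file.

```python
from typing import Optional

IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_EXTS = (".mp4", ".mov", ".avi", ".webm")


def classify(path: str) -> Optional[str]:
    """Return "image" or "video" based on the file extension, else None."""
    lower = path.lower()
    if lower.endswith(IMAGE_EXTS):
        return "image"
    if lower.endswith(VIDEO_EXTS):
        return "video"
    return None


def check_upload(text: str, files: list[str]) -> tuple[bool, str]:
    """Apply the same three rules as validate_input and report the first failure."""
    if not (text.strip() or files):
        return False, "Please enter a message or upload a file."
    # Unsupported files are ignored, as in the original.
    kinds = [k for k in map(classify, files) if k is not None]
    if len(set(kinds)) > 1:
        return False, "Please upload only one type of media (images or video) at a time."
    if kinds.count("video") > 1:
        return False, "Only one video file can be uploaded at a time."
    return True, ""


print(check_upload("", []))                      # rejected: empty message
print(check_upload("hi", ["a.png", "b.JPG"]))    # accepted: several images
print(check_upload("hi", ["a.png", "clip.mp4"]))  # rejected: mixed media
```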
|
||||
|
||||
|
||||
def _has_media_type(messages: list[dict], media_type: str) -> bool:
|
||||
"""Check if any message contains a content entry of the given media type."""
|
||||
return any(
|
||||
c.get("type") == media_type for m in messages for c in (m["content"] if isinstance(m["content"], list) else [])
|
||||
)
|
||||
|
||||
|
||||
def generate(
|
||||
message: dict,
|
||||
history: list[dict],
|
||||
thinking: bool = False,
|
||||
max_new_tokens: int = 1024,
|
||||
max_soft_tokens: int = 280,
|
||||
system_prompt: str = "",
|
||||
) -> Iterator[str]:
|
||||
|
||||
messages: list[dict] = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
|
||||
|
||||
messages.extend(process_history(history))
|
||||
messages.append({"role": "user", "content": process_new_user_message(message)})
|
||||
|
||||
template_kwargs: dict = {
|
||||
"tokenize": True,
|
||||
"return_dict": True,
|
||||
"return_tensors": "pt",
|
||||
"add_generation_prompt": True,
|
||||
"processor_kwargs": {"images_kwargs": {"max_soft_tokens": max_soft_tokens}},
|
||||
}
|
||||
if _has_media_type(messages, "video"):
|
||||
template_kwargs["load_audio_from_video"] = False
|
||||
if thinking:
|
||||
template_kwargs["enable_thinking"] = True
|
||||
|
||||
inputs = processor.apply_chat_template(messages, **template_kwargs)
|
||||
|
||||
n_tokens = inputs["input_ids"].shape[1]
|
||||
if n_tokens > MAX_INPUT_TOKENS:
|
||||
msg = f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
|
||||
raise gr.Error(msg)
|
||||
|
||||
yield from _generate_on_gpu(inputs=inputs, max_new_tokens=max_new_tokens, thinking=thinking)
|
||||
|
||||
|
||||
examples = [
|
||||
# --- Text-only examples ---
|
||||
[
|
||||
{
|
||||
"text": "What is the capital of France?",
|
||||
"files": [],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "What is the water formula?",
|
||||
"files": [],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "Explain quantum entanglement in simple terms.",
|
||||
"files": [],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "I want to do a car wash that is 50 meters away, should I walk or drive?",
|
||||
"files": [],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "Write a poem about beer with 4 stanzas. Format the title as an H2 markdown heading and bold the first line of each stanza.",
|
||||
"files": [],
|
||||
}
|
||||
],
|
||||
# --- Single-image examples ---
|
||||
[
|
||||
{
|
||||
"text": "Describe this image.",
|
||||
"files": ["https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg"],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "What is the city in this image? Describe what you see.",
|
||||
"files": ["https://imgmd.net/images/v1/guia/1698673/rio-de-janeiro-4-c.jpg"],
|
||||
}
|
||||
],
|
||||
# --- Multi-image examples ---
|
||||
[
|
||||
{
|
||||
"text": "What are the key similarities between these three images?",
|
||||
"files": [
|
||||
"https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg",
|
||||
"https://ogimg.infoglobo.com.br/in/12547538-502-0e0/FT1086A/94-8705-14.jpg",
|
||||
"https://amazonasatual.com.br/wp-content/uploads/2021/01/Pele.jpg",
|
||||
],
|
||||
}
|
||||
],
|
||||
# --- Video examples ---
|
||||
[
|
||||
{
|
||||
"text": "What is happening in this video?",
|
||||
"files": ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
fn=generate,
|
||||
validator=validate_input,
|
||||
chatbot=gr.Chatbot(
|
||||
scale=1,
|
||||
latex_delimiters=[
|
||||
{"left": "$$", "right": "$$", "display": True},
|
||||
{"left": "$", "right": "$", "display": False},
|
||||
{"left": "\\(", "right": "\\)", "display": False},
|
||||
{"left": "\\[", "right": "\\]", "display": True},
|
||||
],
|
||||
reasoning_tags=[(THINKING_START, THINKING_END)],
|
||||
),
|
||||
textbox=gr.MultimodalTextbox(
|
||||
sources=["upload"],
|
||||
file_types=[*IMAGE_FILE_TYPES, *VIDEO_FILE_TYPES],
|
||||
file_count="multiple",
|
||||
autofocus=True,
|
||||
),
|
||||
multimodal=True,
|
||||
additional_inputs=[
|
||||
gr.Checkbox(label="Thinking", value=False),
|
||||
gr.Slider(label="Max New Tokens", minimum=100, maximum=4000, step=10, value=2000),
|
||||
gr.Dropdown(
|
||||
label="Image Token Budget",
|
||||
info="Higher values preserve more visual detail (useful for OCR/documents). Lower values are faster.",
|
||||
choices=[70, 140, 280, 560, 1120],
|
||||
value=280,
|
||||
),
|
||||
gr.Textbox(label="System Prompt", value=""),
|
||||
],
|
||||
additional_inputs_accordion=gr.Accordion("Settings", open=True),
|
||||
stop_btn=False,
|
||||
title="Gemma 4 31B It",
|
||||
examples=examples,
|
||||
run_examples_on_click=False,
|
||||
cache_examples=False,
|
||||
delete_cache=(1800, 1800),
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(css_paths="style.css", max_file_size="20mb")
|
||||
@@ -0,0 +1,362 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-hashes --no-dev --group hf-spaces --no-emit-package typer-slim --no-emit-package spaces -o requirements.txt
|
||||
accelerate==1.13.0
|
||||
# via gemma-4-31b-it
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.5
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.13.0
|
||||
# via
|
||||
# gradio
|
||||
# httpx
|
||||
# mcp
|
||||
# sse-starlette
|
||||
# starlette
|
||||
attrs==26.1.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonschema
|
||||
# referencing
|
||||
audioop-lts==0.2.2 ; python_full_version >= '3.13'
|
||||
# via gradio
|
||||
brotli==1.2.0
|
||||
# via gradio
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
cffi==2.0.0 ; platform_python_implementation != 'PyPy'
|
||||
# via cryptography
|
||||
charset-normalizer==3.4.7
|
||||
# via requests
|
||||
click==8.3.2
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.7
|
||||
# via pyjwt
|
||||
datasets==4.8.4
|
||||
dill==0.4.1
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
fastapi==0.136.0
|
||||
# via gradio
|
||||
filelock==3.28.0
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# torch
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2026.2.0
|
||||
# via
|
||||
# datasets
|
||||
# gradio-client
|
||||
# huggingface-hub
|
||||
# torch
|
||||
gradio==6.12.0
|
||||
# via
|
||||
# gemma-4-31b-it
|
||||
# spaces
|
||||
gradio-client==2.4.1
|
||||
# via
|
||||
# gradio
|
||||
# hf-gradio
|
||||
groovy==0.1.2
|
||||
# via gradio
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hf-gradio==0.4.0
|
||||
# via gradio
|
||||
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# gradio
|
||||
# gradio-client
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# safehttpx
|
||||
# spaces
|
||||
httpx-sse==0.4.3
|
||||
# via mcp
|
||||
huggingface-hub==1.11.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# gradio
|
||||
# gradio-client
|
||||
# tokenizers
|
||||
# transformers
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# gradio
|
||||
# torch
|
||||
jsonschema==4.26.0
|
||||
# via mcp
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
markdown-it-py==4.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# gradio
|
||||
# jinja2
|
||||
mcp==1.27.0
|
||||
# via gradio
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.7.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.19
|
||||
# via datasets
|
||||
networkx==3.6.1
|
||||
# via torch
|
||||
numpy==2.4.4
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# gradio
|
||||
# pandas
|
||||
# torchvision
|
||||
# transformers
|
||||
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
# torch
|
||||
nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
orjson==3.11.8
|
||||
# via gradio
|
||||
packaging==26.1
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# gradio
|
||||
# gradio-client
|
||||
# huggingface-hub
|
||||
# spaces
|
||||
# transformers
|
||||
pandas==3.0.2
|
||||
# via
|
||||
# datasets
|
||||
# gradio
|
||||
pillow==12.2.0
|
||||
# via
|
||||
# gradio
|
||||
# torchvision
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psutil==5.9.8
|
||||
# via
|
||||
# accelerate
|
||||
# spaces
|
||||
pyarrow==23.0.1
|
||||
# via datasets
|
||||
pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
|
||||
# via cffi
|
||||
pydantic==2.12.5
|
||||
# via
|
||||
# fastapi
|
||||
# gradio
|
||||
# mcp
|
||||
# pydantic-settings
|
||||
# spaces
|
||||
pydantic-core==2.41.5
|
||||
# via pydantic
|
||||
pydantic-settings==2.13.1
|
||||
# via mcp
|
||||
pydub==0.25.1
|
||||
# via gradio
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
pyjwt==2.12.1
|
||||
# via mcp
|
||||
python-dateutil==2.9.0.post0
|
||||
# via pandas
|
||||
python-dotenv==1.2.2
|
||||
# via pydantic-settings
|
||||
python-multipart==0.0.26
|
||||
# via
|
||||
# gradio
|
||||
# mcp
|
||||
pytz==2026.1.post1
|
||||
# via gradio
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# gradio
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2026.4.4
|
||||
# via transformers
|
||||
requests==2.33.1
|
||||
# via
|
||||
# datasets
|
||||
# spaces
|
||||
rich==15.0.0
|
||||
# via typer
|
||||
rpds-py==0.30.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safehttpx==0.1.7
|
||||
# via gradio
|
||||
safetensors==0.7.0
|
||||
# via
|
||||
# accelerate
|
||||
# transformers
|
||||
semantic-version==2.10.0
|
||||
# via gradio
|
||||
setuptools==82.0.1
|
||||
# via torch
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sse-starlette==3.3.4
|
||||
# via mcp
|
||||
starlette==1.0.0
|
||||
# via
|
||||
# fastapi
|
||||
# gradio
|
||||
# mcp
|
||||
# sse-starlette
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
tokenizers==0.22.2
|
||||
# via transformers
|
||||
tomlkit==0.14.0
|
||||
# via gradio
|
||||
torch==2.9.1
|
||||
# via
|
||||
# accelerate
|
||||
# gemma-4-31b-it
|
||||
# torchvision
|
||||
torchcodec==0.9.1
|
||||
# via gemma-4-31b-it
|
||||
torchvision==0.24.1
|
||||
# via gemma-4-31b-it
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
transformers==5.5.4
|
||||
# via gemma-4-31b-it
|
||||
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
typer==0.24.1
|
||||
# via
|
||||
# gradio
|
||||
# hf-gradio
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# fastapi
|
||||
# gradio
|
||||
# gradio-client
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# spaces
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
|
||||
# via pandas
|
||||
urllib3==2.6.3
|
||||
# via requests
|
||||
uvicorn==0.44.0
|
||||
# via
|
||||
# gradio
|
||||
# mcp
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.23.0
|
||||
# via aiohttp
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
title: Gemma 4 E4B It
|
||||
emoji: 🚀
|
||||
colorFrom: blue
|
||||
colorTo: green
|
||||
sdk: gradio
|
||||
sdk_version: 6.12.0
|
||||
python_version: "3.12.12"
|
||||
app_file: app.py
|
||||
pinned: false
|
||||
---
|
||||
|
||||
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
||||
@@ -0,0 +1,322 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from threading import Thread
|
||||
|
||||
import gradio as gr
|
||||
import spaces
|
||||
import torch
|
||||
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
|
||||
from transformers.generation.streamers import TextIteratorStreamer
|
||||
|
||||
MODEL_ID = "google/gemma-4-e4b-it"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
|
||||
model = AutoModelForMultimodalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16)
|
||||
|
||||
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
|
||||
AUDIO_FILE_TYPES = (".wav", ".mp3", ".flac", ".ogg")
|
||||
VIDEO_FILE_TYPES = (".mp4", ".mov", ".avi", ".webm")
|
||||
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
|
||||
|
||||
THINKING_START = "<|channel>"
|
||||
THINKING_END = "<channel|>"
|
||||
|
||||
# Special tokens to strip from decoded output (keeping thinking delimiters
|
||||
# so that Gradio's reasoning_tags can find them on the frontend).
|
||||
_KEEP_TOKENS = {THINKING_START, THINKING_END}
|
||||
_STRIP_TOKENS = sorted(
|
||||
(t for t in processor.tokenizer.all_special_tokens if t not in _KEEP_TOKENS),
|
||||
key=len,
|
||||
reverse=True, # longest first to avoid partial matches
|
||||
)
|
||||
|
||||
|
||||
def _strip_special_tokens(text: str) -> str:
|
||||
for tok in _STRIP_TOKENS:
|
||||
text = text.replace(tok, "")
|
||||
return text
|
||||
|
||||
|
||||
def _classify_file(path: str) -> str | None:
|
||||
"""Return media type string for a file path, or None if unsupported."""
|
||||
lower = path.lower()
|
||||
if lower.endswith(IMAGE_FILE_TYPES):
|
||||
return "image"
|
||||
if lower.endswith(AUDIO_FILE_TYPES):
|
||||
return "audio"
|
||||
if lower.endswith(VIDEO_FILE_TYPES):
|
||||
return "video"
|
||||
return None
|
||||
|
||||
|
||||
def process_new_user_message(message: dict) -> list[dict]:
|
||||
"""Build content list from the new user message with URL-based media references."""
|
||||
content: list[dict] = []
|
||||
for path in message.get("files", []):
|
||||
kind = _classify_file(path)
|
||||
if kind:
|
||||
content.append({"type": kind, "url": path})
|
||||
content.append({"type": "text", "text": message.get("text", "")})
|
||||
return content
|
||||
|
||||
|
||||
def process_history(history: list[dict]) -> list[dict]:
|
||||
"""Walk Gradio 6 history and build message list with URL-based media references."""
|
||||
messages: list[dict] = []
|
||||
|
||||
for item in history:
|
||||
if item["role"] == "assistant":
|
||||
text_parts = [p["text"] for p in item["content"] if p.get("type") == "text"]
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": " ".join(text_parts)}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
user_content: list[dict] = []
|
||||
for part in item["content"]:
|
||||
if part.get("type") == "text":
|
||||
user_content.append({"type": "text", "text": part["text"]})
|
||||
elif part.get("type") == "file":
|
||||
filepath = part["file"]["path"]
|
||||
kind = _classify_file(filepath)
|
||||
if kind:
|
||||
user_content.append({"type": kind, "url": filepath})
|
||||
if user_content:
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@spaces.GPU(duration=120)
|
||||
@torch.inference_mode()
|
||||
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
|
||||
inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
processor,
|
||||
timeout=30.0,
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=not thinking,
|
||||
)
|
||||
generate_kwargs = {
|
||||
**inputs,
|
||||
"streamer": streamer,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"disable_compile": True,
|
||||
}
|
||||
|
||||
exception_holder: list[Exception] = []
|
||||
|
||||
def _generate() -> None:
|
||||
try:
|
||||
model.generate(**generate_kwargs)
|
||||
except Exception as e: # noqa: BLE001
|
||||
exception_holder.append(e)
|
||||
|
||||
thread = Thread(target=_generate)
|
||||
thread.start()
|
||||
|
||||
chunks: list[str] = []
|
||||
for text in streamer:
|
||||
chunks.append(text)
|
||||
accumulated = "".join(chunks)
|
||||
if thinking:
|
||||
yield _strip_special_tokens(accumulated)
|
||||
else:
|
||||
yield accumulated
|
||||
|
||||
thread.join()
|
||||
if exception_holder:
|
||||
msg = f"Generation failed: {exception_holder[0]}"
|
||||
raise gr.Error(msg)
|
||||


# FBT003 is suppressed below: gr.validate API takes bool as first positional arg.
def validate_input(message: dict) -> dict:
    has_text = bool(message.get("text", "").strip())
    has_files = bool(message.get("files"))
    if not (has_text or has_files):
        return gr.validate(False, "Please enter a message or upload a file.")  # noqa: FBT003

    files = message.get("files", [])
    kinds = [_classify_file(f) for f in files]
    kinds = [k for k in kinds if k is not None]
    unique_kinds = set(kinds)

    if len(unique_kinds) > 1:
        return gr.validate(False, "Please upload only one type of media (images, audio, or video) at a time.")  # noqa: FBT003
    if kinds.count("audio") > 1:
        return gr.validate(False, "Only one audio file can be uploaded at a time.")  # noqa: FBT003
    if kinds.count("video") > 1:
        return gr.validate(False, "Only one video file can be uploaded at a time.")  # noqa: FBT003

    return gr.validate(True, "")  # noqa: FBT003


def _has_media_type(messages: list[dict], media_type: str) -> bool:
    """Check if any message contains a content entry of the given media type."""
    return any(c.get("type") == media_type for m in messages for c in m["content"])


def generate(
    message: dict,
    history: list[dict],
    thinking: bool = False,
    max_new_tokens: int = 1024,
    max_soft_tokens: int = 280,
    system_prompt: str = "",
) -> Iterator[str]:
    messages: list[dict] = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})

    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    template_kwargs: dict = {
        "tokenize": True,
        "return_dict": True,
        "return_tensors": "pt",
        "add_generation_prompt": True,
        "load_audio_from_video": _has_media_type(messages, "video"),
        "processor_kwargs": {"images_kwargs": {"max_soft_tokens": max_soft_tokens}},
    }
    if thinking:
        template_kwargs["enable_thinking"] = True

    inputs = processor.apply_chat_template(messages, **template_kwargs)

    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        msg = f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        raise gr.Error(msg)

    yield from _generate_on_gpu(inputs=inputs, max_new_tokens=max_new_tokens, thinking=thinking)


examples = [
    # --- Text-only examples ---
    [
        {
            "text": "What is the capital of France?",
            "files": [],
        }
    ],
    [
        {
            "text": "What is the water formula?",
            "files": [],
        }
    ],
    [
        {
            "text": "Explain quantum entanglement in simple terms.",
            "files": [],
        }
    ],
    [
        {
            "text": "I want to do a car wash that is 50 meters away, should I walk or drive?",
            "files": [],
        }
    ],
    [
        {
            "text": "Write a poem about beer with 4 stanzas. Format the title as an H2 markdown heading and bold the first line of each stanza.",
            "files": [],
        }
    ],
    # --- Single-image examples ---
    [
        {
            "text": "Describe this image.",
            "files": ["https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg"],
        }
    ],
    [
        {
            "text": "What is the city in this image? Describe what you see.",
            "files": ["https://imgmd.net/images/v1/guia/1698673/rio-de-janeiro-4-c.jpg"],
        }
    ],
    # --- Multi-image examples ---
    [
        {
            "text": "What are the key similarities between these three images?",
            "files": [
                "https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg",
                "https://ogimg.infoglobo.com.br/in/12547538-502-0e0/FT1086A/94-8705-14.jpg",
                "https://amazonasatual.com.br/wp-content/uploads/2021/01/Pele.jpg",
            ],
        }
    ],
    # --- Audio examples ---
    [
        {
            "text": "Transcribe the audio.",
            "files": [
                "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3"
            ],
        }
    ],
    [
        {
            "text": "Translate to Dutch.",
            "files": [
                "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3"
            ],
        }
    ],
    # --- Video examples ---
    [
        {
            "text": "What is happening in this video?",
            "files": ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"],
        }
    ],
]

demo = gr.ChatInterface(
    fn=generate,
    validator=validate_input,
    chatbot=gr.Chatbot(
        scale=1,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
            {"left": "\\(", "right": "\\)", "display": False},
            {"left": "\\[", "right": "\\]", "display": True},
        ],
        reasoning_tags=[(THINKING_START, THINKING_END)],
    ),
    textbox=gr.MultimodalTextbox(
        sources=["upload", "microphone"],
        file_types=[*IMAGE_FILE_TYPES, *AUDIO_FILE_TYPES, *VIDEO_FILE_TYPES],
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Checkbox(label="Thinking", value=False),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=4000, step=10, value=2000),
        gr.Dropdown(
            label="Image Token Budget",
            info="Higher values preserve more visual detail (useful for OCR/documents). Lower values are faster.",
            choices=[70, 140, 280, 560, 1120],
            value=280,
        ),
        gr.Textbox(label="System Prompt", value=""),
    ],
    additional_inputs_accordion=gr.Accordion("Settings", open=True),
    stop_btn=False,
    title="Gemma 4 E4B It",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch(css_paths="style.css", max_file_size="20MB")
@@ -0,0 +1,362 @@
# This file was autogenerated by uv via the following command:
#    uv export --no-hashes --no-dev --group hf-spaces --no-emit-package typer-slim --no-emit-package spaces -o requirements.txt
accelerate==1.13.0
    # via gemma-4-e4b-it
aiohappyeyeballs==2.6.1
    # via aiohttp
aiohttp==3.13.5
    # via fsspec
aiosignal==1.4.0
    # via aiohttp
annotated-doc==0.0.4
    # via
    #   fastapi
    #   typer
annotated-types==0.7.0
    # via pydantic
anyio==4.13.0
    # via
    #   gradio
    #   httpx
    #   mcp
    #   sse-starlette
    #   starlette
attrs==26.1.0
    # via
    #   aiohttp
    #   jsonschema
    #   referencing
audioop-lts==0.2.2 ; python_full_version >= '3.13'
    # via gradio
brotli==1.2.0
    # via gradio
certifi==2026.2.25
    # via
    #   httpcore
    #   httpx
    #   requests
cffi==2.0.0 ; platform_python_implementation != 'PyPy'
    # via cryptography
charset-normalizer==3.4.7
    # via requests
click==8.3.2
    # via
    #   typer
    #   uvicorn
colorama==0.4.6 ; sys_platform == 'win32'
    # via
    #   click
    #   tqdm
cryptography==46.0.7
    # via pyjwt
datasets==4.8.4
dill==0.4.1
    # via
    #   datasets
    #   multiprocess
fastapi==0.136.0
    # via gradio
filelock==3.28.0
    # via
    #   datasets
    #   huggingface-hub
    #   torch
frozenlist==1.8.0
    # via
    #   aiohttp
    #   aiosignal
fsspec==2026.2.0
    # via
    #   datasets
    #   gradio-client
    #   huggingface-hub
    #   torch
gradio==6.12.0
    # via
    #   gemma-4-e4b-it
    #   spaces
gradio-client==2.4.1
    # via
    #   gradio
    #   hf-gradio
groovy==0.1.2
    # via gradio
h11==0.16.0
    # via
    #   httpcore
    #   uvicorn
hf-gradio==0.4.0
    # via gradio
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
    # via huggingface-hub
httpcore==1.0.9
    # via httpx
httpx==0.28.1
    # via
    #   datasets
    #   gradio
    #   gradio-client
    #   huggingface-hub
    #   mcp
    #   safehttpx
    #   spaces
httpx-sse==0.4.3
    # via mcp
huggingface-hub==1.11.0
    # via
    #   accelerate
    #   datasets
    #   gradio
    #   gradio-client
    #   tokenizers
    #   transformers
idna==3.11
    # via
    #   anyio
    #   httpx
    #   requests
    #   yarl
jinja2==3.1.6
    # via
    #   gradio
    #   torch
jsonschema==4.26.0
    # via mcp
jsonschema-specifications==2025.9.1
    # via jsonschema
markdown-it-py==4.0.0
    # via rich
markupsafe==3.0.3
    # via
    #   gradio
    #   jinja2
mcp==1.27.0
    # via gradio
mdurl==0.1.2
    # via markdown-it-py
mpmath==1.3.0
    # via sympy
multidict==6.7.1
    # via
    #   aiohttp
    #   yarl
multiprocess==0.70.19
    # via datasets
networkx==3.6.1
    # via torch
numpy==2.4.4
    # via
    #   accelerate
    #   datasets
    #   gradio
    #   pandas
    #   torchvision
    #   transformers
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via
    #   nvidia-cudnn-cu12
    #   nvidia-cusolver-cu12
    #   torch
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via
    #   nvidia-cusolver-cu12
    #   torch
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via
    #   nvidia-cufft-cu12
    #   nvidia-cusolver-cu12
    #   nvidia-cusparse-cu12
    #   torch
nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
orjson==3.11.8
    # via gradio
packaging==26.1
    # via
    #   accelerate
    #   datasets
    #   gradio
    #   gradio-client
    #   huggingface-hub
    #   spaces
    #   transformers
pandas==3.0.2
    # via
    #   datasets
    #   gradio
pillow==12.2.0
    # via
    #   gradio
    #   torchvision
propcache==0.4.1
    # via
    #   aiohttp
    #   yarl
psutil==5.9.8
    # via
    #   accelerate
    #   spaces
pyarrow==23.0.1
    # via datasets
pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
    # via cffi
pydantic==2.12.5
    # via
    #   fastapi
    #   gradio
    #   mcp
    #   pydantic-settings
    #   spaces
pydantic-core==2.41.5
    # via pydantic
pydantic-settings==2.13.1
    # via mcp
pydub==0.25.1
    # via gradio
pygments==2.20.0
    # via rich
pyjwt==2.12.1
    # via mcp
python-dateutil==2.9.0.post0
    # via pandas
python-dotenv==1.2.2
    # via pydantic-settings
python-multipart==0.0.26
    # via
    #   gradio
    #   mcp
pytz==2026.1.post1
    # via gradio
pywin32==311 ; sys_platform == 'win32'
    # via mcp
pyyaml==6.0.3
    # via
    #   accelerate
    #   datasets
    #   gradio
    #   huggingface-hub
    #   transformers
referencing==0.37.0
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2026.4.4
    # via transformers
requests==2.33.1
    # via
    #   datasets
    #   spaces
rich==15.0.0
    # via typer
rpds-py==0.30.0
    # via
    #   jsonschema
    #   referencing
safehttpx==0.1.7
    # via gradio
safetensors==0.7.0
    # via
    #   accelerate
    #   transformers
semantic-version==2.10.0
    # via gradio
setuptools==82.0.1
    # via torch
shellingham==1.5.4
    # via typer
six==1.17.0
    # via python-dateutil
sse-starlette==3.3.4
    # via mcp
starlette==1.0.0
    # via
    #   fastapi
    #   gradio
    #   mcp
    #   sse-starlette
sympy==1.14.0
    # via torch
tokenizers==0.22.2
    # via transformers
tomlkit==0.14.0
    # via gradio
torch==2.9.1
    # via
    #   accelerate
    #   gemma-4-e4b-it
    #   torchvision
torchcodec==0.9.1
    # via gemma-4-e4b-it
torchvision==0.24.1
    # via gemma-4-e4b-it
tqdm==4.67.3
    # via
    #   datasets
    #   huggingface-hub
    #   transformers
transformers==5.5.4
    # via gemma-4-e4b-it
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
    # via torch
typer==0.24.1
    # via
    #   gradio
    #   hf-gradio
    #   huggingface-hub
    #   transformers
typing-extensions==4.15.0
    # via
    #   aiosignal
    #   anyio
    #   fastapi
    #   gradio
    #   gradio-client
    #   huggingface-hub
    #   mcp
    #   pydantic
    #   pydantic-core
    #   referencing
    #   spaces
    #   starlette
    #   torch
    #   typing-inspection
typing-inspection==0.4.2
    # via
    #   fastapi
    #   mcp
    #   pydantic
    #   pydantic-settings
tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
    # via pandas
urllib3==2.6.3
    # via requests
uvicorn==0.44.0
    # via
    #   gradio
    #   mcp
xxhash==3.6.0
    # via datasets
yarl==1.23.0
    # via aiohttp
@@ -0,0 +1,33 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_gemma4 import *
    from .feature_extraction_gemma4 import *
    from .image_processing_gemma4 import *
    from .image_processing_pil_gemma4 import *
    from .modeling_gemma4 import *
    from .processing_gemma4 import *
    from .video_processing_gemma4 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
@@ -0,0 +1,352 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Literal

from huggingface_hub.dataclasses import strict

from ...configuration_utils import PreTrainedConfig
from ...utils import auto_docstring, logging
from ...utils.type_validators import interval


logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4AudioConfig(PreTrainedConfig):
    r"""
    subsampling_conv_channels (`list[int]`, defaults to `[128, 32]`):
        Channel sizes for the convolutional layers in the Sub-sample Convolution Projection.
    residual_weight (`float`, defaults to `0.5`):
        Scaling applied to hidden_states prior to combining with the residual in the feedforward.
    attention_chunk_size (`int`, defaults to `12`):
        The sub-sequence size for attention processing.
    attention_context_left (`int`, defaults to `13`):
        The leftward context size for the attention chunk.
    attention_context_right (`int`, defaults to `0`):
        The rightward context size for the attention chunk.
    attention_logit_cap (`float`, defaults to `50.0`):
        Cap applied to attention weights.
    attention_invalid_logits_value (`float`, defaults to `-1e9`):
        Value to use for invalid logits in attention.
    use_clipped_linears (`bool`, defaults to `True`):
        If true, apply clipping to the Linear layers, drawing bounds from the model checkpoint.
    gradient_clipping (`float`, defaults to `1e10`):
        Clipping value used to stabilize extremely large gradient values.
    output_proj_dims (`int`, defaults to `1536`):
        Dimension of the final linear projection from `hidden_size` to the model's output.
    """

    model_type = "gemma4_audio"

    hidden_size: int = 1024
    num_hidden_layers: int = 12
    num_attention_heads: int = 8
    hidden_act: str = "silu"

    # subsampling parameters
    subsampling_conv_channels: list[int] | tuple[int, int] = (128, 32)

    # conformer parameters
    conv_kernel_size: int = 5
    residual_weight: float = 0.5
    attention_chunk_size: int = 12
    attention_context_left: int = 13
    attention_context_right: int = 0
    attention_logit_cap: float = 50.0
    attention_invalid_logits_value: float = -1.0e9

    use_clipped_linears: bool = True
    rms_norm_eps: float = 1e-6
    gradient_clipping: float = 1e10
    output_proj_dims: int = 1536
    initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)

    def __post_init__(self, **kwargs):
        # JSON serialization converts tuples to lists, convert back
        if isinstance(self.subsampling_conv_channels, tuple):
            self.subsampling_conv_channels = list(self.subsampling_conv_channels)
        super().__post_init__(**kwargs)


@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4TextConfig(PreTrainedConfig):
    r"""
    use_bidirectional_attention (`str`, *optional*):
        Controls bidirectional attention behavior. When set to `"vision"`, vision tokens
        attend bidirectionally while text tokens use causal attention. When set to `"all"`,
        all tokens use bidirectional attention.
    vocab_size_per_layer_input (`int`, defaults to 262144):
        Vocabulary size for the per-layer input embeddings (PLE). Used by models with
        per-layer residual streams where a smaller embedding is added at each decoder layer.
    hidden_size_per_layer_input (`int`, defaults to 256):
        Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
        `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
        because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#per-layer-embeddings-ple) docs
        for a description of the full PLE pipeline.
    num_global_key_value_heads (`int`, *optional*):
        Number of key-value heads for global (full) attention layers. If `None`, defaults
        to `num_key_value_heads`.
    global_head_dim (`int`, defaults to 512):
        Dimension of each attention head in global (full) attention layers.
    attention_k_eq_v (`bool`, defaults to `False`):
        Whether keys and values share the same projection weights. When `True`, the key
        projection output is reused as the value projection.
    num_kv_shared_layers (`int`, defaults to 0):
        Number of consecutive decoder layers that share the same key-value projections.
        A value of 0 means no sharing (each layer has independent KV projections).
    enable_moe_block (`bool`, defaults to `False`):
        Whether to enable Mixture-of-Experts (MoE) blocks in the decoder layers. When
        `True`, eligible layers will use a sparse MoE feed-forward network.
    use_double_wide_mlp (`bool`, defaults to `False`):
        Whether to use a double-width MLP with fused gate and up projections.
    top_k_experts (`int`, *optional*):
        Number of experts activated per token in MoE layers. Only used when
        `enable_moe_block=True`.
    moe_intermediate_size (`int`, *optional*):
        Intermediate (hidden) size of each expert's feed-forward network in MoE layers.
        Only used when `enable_moe_block=True`.
    """

    model_type = "gemma4_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
        "layers.*.experts.gate_up_proj": "packed_colwise",
        "layers.*.experts.down_proj": "rowwise",
        "layers.*.experts": "moe_tp_experts",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    vocab_size: int = 262_144
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 30
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 256
    hidden_activation: str = "gelu_pytorch_tanh"
    max_position_embeddings: int = 131_072
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = 0
    eos_token_id: int | list[int] | None = 1
    bos_token_id: int | None = 2
    tie_word_embeddings: bool = True
    rope_parameters: dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float | None = 0.0
    sliding_window: int = 512
    layer_types: list[str] | None = None
    final_logit_softcapping: float | None = None
    use_bidirectional_attention: Literal["all", "vision"] | None = None
    vocab_size_per_layer_input: int = 262_144
    hidden_size_per_layer_input: int = 256
    num_global_key_value_heads: int | None = None
    global_head_dim: int = 512
    attention_k_eq_v: bool = False
    num_kv_shared_layers: int = 0
    enable_moe_block: bool = False
    use_double_wide_mlp: bool = False
    num_experts: int | None = None
    top_k_experts: int | None = None
    moe_intermediate_size: int | None = None

    def __post_init__(self, **kwargs):
        if self.use_bidirectional_attention == "all":
            self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds

        if self.layer_types is None:
            sliding_window_pattern = 6  # by default 5:1
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]

        if self.layer_types and (last_layer_type := self.layer_types[-1]) != "full_attention":
            logger.warning(
                f"Last layer must use `full_attention`, but got `{last_layer_type}`. Forcing last layer to `full_attention`."
            )
            self.layer_types[-1] = "full_attention"

        default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
            "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
            "full_attention": {"rope_type": "proportional", "partial_rotary_factor": 0.25, "rope_theta": 1_000_000.0},
        }
        if self.rope_parameters is None:
            self.rope_parameters = default_rope_params

        super().__post_init__(**kwargs)

    def convert_rope_params_to_dict(self, **kwargs):
        # No need to handle BC for new models, because they have no old-format `rope_scaling`
        return kwargs
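To make the defaults above concrete, the following sketch recomputes three derived quantities in plain Python: the 5:1 sliding/full layer pattern, the halved sliding window used when `use_bidirectional_attention == "all"`, and the packed PLE table shape from the docstring. The helper names `default_layer_types` and `ple_table_shape` are hypothetical; only the arithmetic is taken from the config.

```python
def default_layer_types(num_hidden_layers: int, pattern: int = 6) -> list[str]:
    """Every `pattern`-th layer is full attention; the last layer is forced to full."""
    types = [
        "sliding_attention" if (i + 1) % pattern else "full_attention"
        for i in range(num_hidden_layers)
    ]
    if types and types[-1] != "full_attention":
        types[-1] = "full_attention"
    return types


def ple_table_shape(
    vocab_size_per_layer_input: int,
    num_hidden_layers: int,
    hidden_size_per_layer_input: int,
) -> tuple[int, int]:
    """All layers' per-layer embeddings are packed side by side in one table."""
    return (vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input)


layers = default_layer_types(30)
print([i for i, t in enumerate(layers) if t == "full_attention"])  # [5, 11, 17, 23, 29]
print(default_layer_types(8)[-2:])  # ['sliding_attention', 'full_attention']
print((512 // 2) + 1)  # 257, the sliding window when all tokens attend bidirectionally
print(ple_table_shape(262_144, 30, 256))  # (262144, 7680)
```

With the default 30 layers the pattern already ends on a full-attention layer, so the forcing branch only matters for layer counts that are not a multiple of 6.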


@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4VisionConfig(PreTrainedConfig):
    r"""
    pooling_kernel_size (`int`, *optional*):
        Spatial pooling kernel size applied after patchification.
    position_embedding_size (`int`, defaults to 10240):
        Maximum number of position embeddings for the vision encoder. Controls the size of
        the learned 2D position embedding table used by the patch embedder.
    use_clipped_linears (`bool`, defaults to `False`):
        Whether to use weight-clipped linear layers. When enabled, linear layer weights are
        clamped to a fixed range during the forward pass to improve numerical stability.
    standardize (`bool`, defaults to `False`):
        If true, applies a bias and scale to the soft tokens returned from the pooler.
    """

    model_type = "gemma4_vision"
    base_model_tp_plan = {
        "encoder.layers.*.self_attn.q_proj": "colwise",
        "encoder.layers.*.self_attn.k_proj": "colwise",
        "encoder.layers.*.self_attn.v_proj": "colwise",
        "encoder.layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
        "encoder.layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
        "encoder.layers.*.self_attn.o_proj": "rowwise",
        "encoder.layers.*.mlp.gate_proj": "colwise",
        "encoder.layers.*.mlp.up_proj": "colwise",
        "encoder.layers.*.mlp.down_proj": "rowwise",
    }
    default_theta = 100.0

    hidden_size: int = 768
    intermediate_size: int = 3072
    num_hidden_layers: int = 16
    num_attention_heads: int = 12
    num_key_value_heads: int = 12
    head_dim: int = 64
    hidden_activation: str = "gelu_pytorch_tanh"
    rms_norm_eps: float = 1e-6
    max_position_embeddings: int = 131_072
    attention_bias: bool | None = False
    attention_dropout: float | None = 0.0
    rope_parameters: dict | None = None
    pooling_kernel_size: int = 3
    patch_size: int = 16
    position_embedding_size: int = 10 * 1024
    use_clipped_linears: bool = False
    standardize: bool = False
    initializer_range: float = 0.02

    def __post_init__(self, **kwargs):
        if self.rope_parameters is None:
            self.rope_parameters = {"rope_type": "default", "rope_theta": 100.0}

        super().__post_init__(**kwargs)


@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4Config(PreTrainedConfig):
    r"""
    boi_token_id (`int`, *optional*, defaults to 255999):
        The begin-of-image token index to wrap the image prompt.
    eoi_token_id (`int`, *optional*, defaults to 258882):
        The end-of-image token index to wrap the image prompt.
    boa_token_id (`int`, *optional*, defaults to 256000):
        The begin-of-audio token index to wrap the audio prompt.
    eoa_token_index (`int`, *optional*, defaults to 258883):
        The end-of-audio token index to wrap the audio prompt.

    Example:

    ```python
    >>> from transformers import (
    ...     Gemma4AudioConfig,
    ...     Gemma4Config,
    ...     Gemma4ForConditionalGeneration,
    ...     Gemma4TextConfig,
    ...     Gemma4VisionConfig,
    ... )

    >>> # Initializing a Gemma 4 Audio config.
    >>> audio_config = Gemma4AudioConfig()

    >>> # Initializing a Gemma 4 Text config.
    >>> text_config = Gemma4TextConfig()

    >>> # Initializing a Gemma 4 vision config.
    >>> vision_config = Gemma4VisionConfig()

    >>> # Initializing a Gemma 4 config similar to google/gemma-4-e2b-it
    >>> configuration = Gemma4Config(text_config, vision_config, audio_config)

    >>> # Initializing a model from the google/gemma-4-e2b-it configuration
    >>> model = Gemma4ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma4"
    sub_configs = {
        "text_config": Gemma4TextConfig,
        "vision_config": Gemma4VisionConfig,
        "audio_config": Gemma4AudioConfig,
    }

    text_config: Gemma4TextConfig | dict[str, Any] | None = None
    vision_config: Gemma4VisionConfig | dict[str, Any] | None = None
    audio_config: Gemma4AudioConfig | dict[str, Any] | None = None
    boi_token_id: int | None = 255_999
    eoi_token_id: int | None = 258_882
    image_token_id: int | None = 258_880
    video_token_id: int | None = 258_884
    boa_token_id: int | None = 256_000
    eoa_token_index: int | None = 258_883
    audio_token_id: int | None = 258_881
    initializer_range: float | None = 0.02
    tie_word_embeddings: bool = True

    def __post_init__(self, **kwargs):
        if self.text_config is None:
            self.text_config = Gemma4TextConfig()
            logger.info("text_config is None. Using default Gemma4TextConfig.")
        elif isinstance(self.text_config, dict):
            self.text_config = Gemma4TextConfig(**self.text_config)

        if self.vision_config is None:
            logger.info("vision_config is None. Gemma4Model.vision_tower will not be initialized.")
        if isinstance(self.vision_config, dict):
            self.vision_config = Gemma4VisionConfig(**self.vision_config)

        if self.audio_config is None:
            logger.info("audio_config is None. Gemma4Model.audio_tower will not be initialized.")
        if isinstance(self.audio_config, dict):
            self.audio_config = Gemma4AudioConfig(**self.audio_config)

        super().__post_init__(**kwargs)


__all__ = ["Gemma4AudioConfig", "Gemma4Config", "Gemma4TextConfig", "Gemma4VisionConfig"]
@@ -0,0 +1,298 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from collections.abc import Sequence

import numpy as np

from ...audio_utils import mel_filter_bank, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging


logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
|
||||
"""A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
|
||||
if array.ndim != 2:
|
||||
raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
|
||||
if dimension != -1 and dimension != array.ndim - 1:
|
||||
raise ValueError("This unfold implementation only supports unfolding the last dimension.")
|
||||
|
||||
batch_size, original_length = array.shape
|
||||
num_frames = (original_length - size) // step + 1
|
||||
|
||||
if num_frames <= 0:
|
||||
return np.zeros((batch_size, 0, size), dtype=array.dtype)
|
||||
|
||||
output_shape = (batch_size, num_frames, size)
|
||||
output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
|
||||
|
||||
return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
|
||||
|
||||
|
||||
class Gemma4AudioFeatureExtractor(SequenceFeatureExtractor):
|
||||
"""An audio feature extractor for Universal Speech Models (https://huggingface.co/papers/2303.01037).
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 128):
|
||||
The feature dimension of the extracted features.
|
||||
sampling_rate (`int`, *optional*, defaults to 16000):
|
||||
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
Padding value used to pad the audio. Should correspond to silences.
|
||||
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the attention mask for the generated MEL spectrograms.
|
||||
frame_length_ms (`float`, *optional*, defaults to 20.0):
|
||||
The length of a frame in milliseconds.
|
||||
hop_length_ms (`float`, *optional*, defaults to 10.0):
|
||||
The hop (stride) in milliseconds between successive overlapping STFT frames used to obtain the Mel spectrogram.
|
||||
min_frequency (`float`, *optional*, defaults to 0.0):
|
||||
The minimum frequency (in Hz) for the Mel filterbank.
|
||||
max_frequency (`float`, *optional*, defaults to 8000.0):
|
||||
The maximum frequency (in Hz) for the Mel filterbank.
|
||||
preemphasis (`float`, *optional*, defaults to 0.0):
|
||||
The preemphasis coefficient.
|
||||
preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use HTK-style preemphasis.
|
||||
fft_overdrive (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use FFT overdrive, which doubles the FFT length beyond the next power of two of the frame length.
|
||||
dither (`float`, *optional*, defaults to 0.0):
|
||||
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||
E.g. use 0.0001 to add dithering with a normal distribution centered
|
||||
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
|
||||
The value 0.0 means no dithering.
|
||||
Dithering has a similar effect to `spectrogram(mel_floor=...)`. It reduces
|
||||
the high log_mel_fbank values for signals with hard-zero sections,
|
||||
when VAD cutoff is present in the signal.
|
||||
input_scale_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor applied to the input waveform.
|
||||
mel_floor (`float`, *optional*, defaults to 0.001):
|
||||
Minimum value for Mel spectrograms to avoid log(0).
|
||||
per_bin_mean (`Optional[Sequence[float]]`, *optional*):
|
||||
Mean values for per-bin normalization.
|
||||
per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
|
||||
Standard deviation values for per-bin normalization.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_features", "input_features_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size: int = 128,
|
||||
sampling_rate: int = 16_000,
|
||||
padding_value: float = 0.0,
|
||||
return_attention_mask: bool = True,
|
||||
frame_length_ms: float = 20.0,
|
||||
hop_length_ms: float = 10.0,
|
||||
min_frequency: float = 0.0,
|
||||
max_frequency: float = 8000.0,
|
||||
preemphasis: float = 0.0,
|
||||
preemphasis_htk_flavor: bool = True,
|
||||
fft_overdrive: bool = False,
|
||||
dither: float = 0.0,
|
||||
input_scale_factor: float = 1.0,
|
||||
mel_floor: float = 1e-3,
|
||||
per_bin_mean: Sequence[float] | None = None,
|
||||
per_bin_stddev: Sequence[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
feature_size=feature_size,
|
||||
sampling_rate=sampling_rate,
|
||||
padding_value=padding_value,
|
||||
return_attention_mask=return_attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.min_frequency = min_frequency
|
||||
self.max_frequency = max_frequency
|
||||
self.preemphasis = preemphasis
|
||||
self.preemphasis_htk_flavor = preemphasis_htk_flavor
|
||||
self.fft_overdrive = fft_overdrive
|
||||
self.dither = dither
|
||||
self.input_scale_factor = input_scale_factor
|
||||
self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
|
||||
self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
|
||||
self.mel_floor = np.array(mel_floor, dtype=np.float64)
|
||||
|
||||
fft_length = 2 ** math.ceil(math.log2(self.frame_length))
|
||||
if self.fft_overdrive:
|
||||
fft_length *= 2
|
||||
self.fft_length = fft_length
|
||||
|
||||
# Use periodic Hann window, matching sl.STFT default (signal.hann_window)
|
||||
# For even frame_length: window[n] = 0.5 - 0.5 * cos(2*pi*n / frame_length)
|
||||
self.window = window_function(self.frame_length).astype(np.float32)
|
||||
|
||||
# Use HuggingFace's mel_filter_bank for compatibility.
|
||||
# Suppress the expected warning about all-zero upper mel filters;
|
||||
# with fft_length=512 (257 bins) and 128 mel filters the uppermost
|
||||
# triangular filter falls between frequency bins, which is harmless.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self.mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=self.fft_length // 2 + 1,
|
||||
num_mel_filters=feature_size,
|
||||
min_frequency=min_frequency,
|
||||
max_frequency=max_frequency,
|
||||
sampling_rate=self.sampling_rate,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
|
||||
if per_bin_mean is not None:
|
||||
self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
|
||||
else:
|
||||
self.per_bin_mean = None
|
||||
|
||||
if per_bin_stddev is not None:
|
||||
self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
|
||||
else:
|
||||
self.per_bin_stddev = None
|
||||
|
||||
def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Compute the log-mel spectrogram and per-frame validity mask for one waveform."""
|
||||
if waveform.ndim == 1: # If single waveform, add batch dimension
|
||||
waveform = np.expand_dims(waveform, axis=0)
|
||||
|
||||
if self.dither > 0.0:
|
||||
waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
|
||||
|
||||
if self.input_scale_factor != 1.0:
|
||||
waveform = waveform * self.input_scale_factor
|
||||
|
||||
# Semicausal time padding: prepend frame_length // 2 zeros so that the
|
||||
# first STFT frame is centered at t=0, matching sl.STFT(time_padding='semicausal').
|
||||
pad_left = self.frame_length // 2
|
||||
waveform = np.pad(waveform, ((0, 0), (pad_left, 0)), mode="constant")
|
||||
attention_mask = np.pad(attention_mask, (pad_left, 0), mode="constant", constant_values=0)
|
||||
|
||||
frame_size_for_unfold = self.frame_length + 1
|
||||
|
||||
# NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
|
||||
frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
|
||||
|
||||
if self.preemphasis > 0.0:
|
||||
if self.preemphasis_htk_flavor:
|
||||
first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
|
||||
rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
|
||||
frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
|
||||
else:
|
||||
frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
|
||||
else:
|
||||
frames = frames_to_process[..., :-1]
|
||||
|
||||
# Apply window, then RFFT. np.fft.rfft with n=fft_length implicitly
|
||||
# right-pads frames to fft_length.
|
||||
frames = frames * self.window # Broadcasting window
|
||||
stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
|
||||
|
||||
magnitude_spec = np.abs(stft)
|
||||
|
||||
mel_spec = np.matmul(magnitude_spec, self.mel_filters)
|
||||
log_mel_spec = np.log(mel_spec + self.mel_floor)
|
||||
|
||||
if self.per_bin_mean is not None:
|
||||
log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
|
||||
|
||||
if self.per_bin_stddev is not None:
|
||||
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
|
||||
|
||||
mel_spectrogram = log_mel_spec.squeeze(0)
|
||||
num_mel_frames = mel_spectrogram.shape[0]
|
||||
|
||||
# Build a frame-aware mask: a mel frame is valid only when every sample
|
||||
# in its analysis window [i*hop, i*hop + frame_size - 1] is real audio.
|
||||
# We check this by looking at the last sample of each frame's window.
|
||||
frame_end_indices = np.arange(num_mel_frames) * self.hop_length + frame_size_for_unfold - 1
|
||||
mask = attention_mask[frame_end_indices].astype(bool)
|
||||
return mel_spectrogram, mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
|
||||
padding: bool | str | PaddingStrategy = "longest",
|
||||
max_length: int | None = 480_000,
|
||||
truncation: bool = True,
|
||||
pad_to_multiple_of: int | None = 128,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
return_attention_mask: bool | None = True,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""Creates a batch of MEL spectrograms from the provided raw speech.
|
||||
|
||||
This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
|
||||
`transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
|
||||
carefully when selecting an audio feature extractor, especially with pre-trained models.
|
||||
|
||||
Args:
|
||||
raw_speech:
|
||||
The audio for which MEL spectrograms are created.
|
||||
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
|
||||
The padding strategy to use for batches of audio with different lengths.
|
||||
max_length (`int`, *optional*, defaults to 480000):
|
||||
If provided, defines the maximum length of the audio to allow. Audio longer than this will be
|
||||
truncated if `truncation=True`.
|
||||
truncation (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to truncate audio above `max_length`.
|
||||
pad_to_multiple_of (`int`, *optional*, defaults to 128):
|
||||
When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
|
||||
return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
|
||||
The type of tensors to return (e.g., NumPy, or Torch).
|
||||
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the attention mask for the generated MEL spectrograms.
|
||||
"""
|
||||
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
|
||||
is_batched = is_batched_numpy or is_batched_sequence
|
||||
|
||||
if is_batched:
|
||||
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
|
||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||
raw_speech = np.asarray(raw_speech)
|
||||
|
||||
if not is_batched: # always return a batch
|
||||
raw_speech = [np.asarray([raw_speech])]
|
||||
|
||||
batched_speech = self.pad(
|
||||
BatchFeature({"input_features": raw_speech}),
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
)
|
||||
|
||||
prepared_speech = []
|
||||
prepared_speech_mask = []
|
||||
for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
|
||||
speech, mask = self._extract_spectrogram(speech.T, mask)
|
||||
prepared_speech.append(speech.astype(np.float32))
|
||||
prepared_speech_mask.append(mask)
|
||||
|
||||
prepared_speech = [speech * mask[..., None] for speech, mask in zip(prepared_speech, prepared_speech_mask)]
|
||||
|
||||
return BatchFeature(
|
||||
{"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Gemma4AudioFeatureExtractor"]
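
The framing arithmetic in `Gemma4AudioFeatureExtractor` can be checked in isolation. The sketch below reproduces the frame, hop, and FFT lengths from the defaults and the frame-validity rule (a frame counts as valid when the last sample of its window is real audio). It is an illustration under assumptions, not the library code: the helper names are hypothetical, and `np.lib.stride_tricks.sliding_window_view` is swapped in for the `as_strided` call used in `_unfold`.

```python
import math

import numpy as np


def frame_params(sampling_rate=16_000, frame_length_ms=20.0, hop_length_ms=10.0):
    """Samples per frame, samples per hop, and FFT length (next power of two)."""
    frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
    hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
    fft_length = 2 ** math.ceil(math.log2(frame_length))
    return frame_length, hop_length, fft_length


def unfold_last_dim(array, size, step):
    """Windows of `size` samples taken every `step` samples from a (batch, time) array."""
    if array.shape[-1] < size:
        # Too short for even one window: return an empty frame axis.
        return np.zeros((array.shape[0], 0, size), dtype=array.dtype)
    windows = np.lib.stride_tricks.sliding_window_view(array, size, axis=-1)
    return windows[:, ::step, :]


def frame_mask(sample_mask, frame_length, hop_length):
    """Per-frame validity from a per-sample mask, checking each window's last sample."""
    pad_left = frame_length // 2  # semicausal left padding, as in the source
    padded = np.pad(sample_mask, (pad_left, 0), mode="constant", constant_values=0)
    size = frame_length + 1  # the source unfolds one extra sample for preemphasis
    num_frames = max((padded.shape[0] - size) // hop_length + 1, 0)
    frame_end_indices = np.arange(num_frames) * hop_length + size - 1
    return padded[frame_end_indices].astype(bool)


frame_length, hop_length, fft_length = frame_params()  # defaults: 16 kHz, 20 ms, 10 ms
frames = unfold_last_dim(np.arange(10).reshape(1, 10), size=4, step=3)
```

With the default 16 kHz, 20 ms, and 10 ms settings this gives 320-sample frames, a 160-sample hop, and a 512-point FFT, which matches the `fft_length=512 (257 bins)` remark in the source comments.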
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
|
||||
from ...image_processing_backends import TorchvisionBackend
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, PILImageResampling
|
||||
from ...processing_utils import ImagesKwargs, Unpack
|
||||
from ...utils import TensorType, auto_docstring, logging
|
||||
from .image_processing_pil_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.siglip2.image_processing_siglip2.convert_image_to_patches
|
||||
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
|
||||
"""
|
||||
Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape
|
||||
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
num_channels, image_height, image_width = image.shape
|
||||
num_patches_height = image_height // patch_size
|
||||
num_patches_width = image_width // patch_size
|
||||
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
|
||||
patched_image = patched_image.permute(1, 3, 2, 4, 0)
|
||||
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
|
||||
return patched_image
|
||||
|
||||
|
||||
# Adapted from Siglip2 (mask -> position ids)
|
||||
def pad_along_first_dim(
|
||||
image: "torch.Tensor", positions: "torch.Tensor", target_length: int
|
||||
) -> tuple["torch.Tensor", "torch.Tensor"]:
|
||||
"""
|
||||
Pad the tensor along the first dimension.
|
||||
"""
|
||||
current_length = image.shape[0]
|
||||
padding_length = target_length - current_length
|
||||
if padding_length > 0:
|
||||
padding = [0, 0] * (image.ndim - 1) + [0, padding_length]
|
||||
pos_padding = (0, 0, 0, padding_length)
|
||||
image = torch.nn.functional.pad(image, padding, mode="constant", value=0)
|
||||
positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1)
|
||||
return image, positions
|
||||
|
||||
|
||||
class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False):
|
||||
"""
|
||||
patch_size (`int`, *optional*):
|
||||
Size of each image patch in pixels.
|
||||
max_soft_tokens (`int`, *optional*):
|
||||
Maximum number of soft (vision) tokens per image.
|
||||
Must be one of {70, 140, 280, 560, 1120}.
|
||||
pooling_kernel_size (`int`, *optional*):
|
||||
Spatial pooling kernel size applied after patchification.
|
||||
"""
|
||||
|
||||
patch_size: int
|
||||
max_soft_tokens: int
|
||||
pooling_kernel_size: int
|
||||
|
||||
|
||||
@auto_docstring(custom_intro="Constructs a Gemma4 image processor.")
|
||||
class Gemma4ImageProcessor(TorchvisionBackend):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = [0.0, 0.0, 0.0]
|
||||
image_std = [1.0, 1.0, 1.0]
|
||||
size = None
|
||||
default_to_square = True
|
||||
do_convert_rgb = True
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = False
|
||||
patch_size = 16
|
||||
max_soft_tokens = 280
|
||||
pooling_kernel_size = 3
|
||||
valid_kwargs = Gemma4ImageProcessorKwargs
|
||||
model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
|
||||
|
||||
def _validate_preprocess_kwargs(self, **kwargs):
|
||||
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
|
||||
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
|
||||
# parameter. Temporarily disable do_resize so the base validation
|
||||
# doesn't require `size` to be set.
|
||||
kwargs["do_resize"] = False
|
||||
super()._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
def aspect_ratio_preserving_resize(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
resample: F.InterpolationMode,
|
||||
) -> torch.Tensor:
|
||||
height, width = image.shape[-2], image.shape[-1]
|
||||
target_height, target_width = get_aspect_ratio_preserving_size(
|
||||
height=height,
|
||||
width=width,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
)
|
||||
|
||||
if target_height == height and target_width == width:
|
||||
return image
|
||||
|
||||
return F.resize(
|
||||
image,
|
||||
size=[target_height, target_width],
|
||||
interpolation=resample,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[Gemma4ImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
resample: "PILImageResampling | F.InterpolationMode | int | None",
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
return_tensors: str | TensorType | None,
|
||||
patch_size: int | None = None,
|
||||
max_soft_tokens: int | None = None,
|
||||
pooling_kernel_size: int | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
|
||||
|
||||
# Compute max_patches from max_soft_tokens and pooling_kernel_size
|
||||
max_patches = max_soft_tokens * pooling_kernel_size**2
|
||||
|
||||
# Process each image individually: resize, rescale/normalize, patchify, pad.
|
||||
# Images have different aspect ratios and thus different resized dimensions,
|
||||
# so patchification and padding must happen per-image before stacking.
|
||||
pixel_values = []
|
||||
position_ids = []
|
||||
num_soft_tokens_per_image = []
|
||||
|
||||
for image in images:
|
||||
# Step 1: Aspect-ratio-preserving resize
|
||||
if do_resize:
|
||||
image = self.aspect_ratio_preserving_resize(
|
||||
image=image,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
# Step 2: Rescale pixel values (typically to [0, 1]) and optionally identity normalize
|
||||
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||
|
||||
# Step 3: Patchify the image
|
||||
# (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
|
||||
patch_height = image.shape[-2] // patch_size
|
||||
patch_width = image.shape[-1] // patch_size
|
||||
patches = convert_image_to_patches(image, patch_size)
|
||||
num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2)
|
||||
|
||||
# Step 4: Compute position IDs
|
||||
device = image.device
|
||||
patch_grid = torch.meshgrid(
|
||||
torch.arange(patch_width, device=device),
|
||||
torch.arange(patch_height, device=device),
|
||||
indexing="xy",
|
||||
)
|
||||
stacked_grid = torch.stack(patch_grid, dim=-1)
|
||||
real_positions = stacked_grid.reshape(patches.shape[0], 2)
|
||||
|
||||
# Final step: pad patches and positions to `max_patches`
|
||||
patches, positions = pad_along_first_dim(patches, real_positions, max_patches)
|
||||
pixel_values.append(patches)
|
||||
position_ids.append(positions)
|
||||
|
||||
# Stack into batch tensors
|
||||
pixel_values = torch.stack(pixel_values, dim=0) # (batch, max_patches, patch_pixels)
|
||||
position_ids = torch.stack(position_ids, dim=0) # (batch, max_patches, 2)
|
||||
|
||||
data = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_position_ids": position_ids,
|
||||
"num_soft_tokens_per_image": num_soft_tokens_per_image,
|
||||
}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Gemma4ImageProcessor"]
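
The patch budget that drives `_preprocess` is easier to follow with numbers. The sketch below recomputes `max_patches` and the rounded-down target size for one image, using the class defaults (`patch_size=16`, `max_soft_tokens=280`, `pooling_kernel_size=3`). `sketch_target_size` is a hypothetical helper that mirrors the main path of `get_aspect_ratio_preserving_size` only; it deliberately omits the zero-dimension edge cases and the final budget check.

```python
import math


def sketch_target_size(height, width, patch_size=16, max_soft_tokens=280, pooling_kernel_size=3):
    """Return (target_height, target_width, max_patches) for the common, non-degenerate case."""
    # Each soft token pools a pooling_kernel_size x pooling_kernel_size block of patches.
    max_patches = max_soft_tokens * pooling_kernel_size**2
    target_px = max_patches * patch_size**2

    # Uniform scale factor that would hit the pixel budget exactly.
    factor = math.sqrt(target_px / (height * width))

    # Both sides must be multiples of pooling_kernel_size * patch_size, so round down.
    side_mult = pooling_kernel_size * patch_size
    target_height = math.floor(factor * height / side_mult) * side_mult
    target_width = math.floor(factor * width / side_mult) * side_mult
    return target_height, target_width, max_patches


target_height, target_width, max_patches = sketch_target_size(480, 640)
num_patches = (target_height // 16) * (target_width // 16)
num_soft_tokens = num_patches // 3**2
```

For a 480x640 input this yields a 672x912 target, 2394 patches out of a 2520-patch budget, and 266 soft tokens. Because both sides are rounded down to a multiple of 48, the result stays within the 280-token limit.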
|
||||
@@ -0,0 +1,278 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_backends import PilBackend
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_transforms import resize
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, Unpack
|
||||
from ...utils import TensorType, auto_docstring, is_vision_available, logging
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120)
|
||||
|
||||
|
||||
def get_aspect_ratio_preserving_size(
|
||||
height: int,
|
||||
width: int,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Image is resized to preserve aspect ratio so it fits within the patch budget.
|
||||
Target dimensions are the largest that:
|
||||
1) Produce at most `max_patches` patches when patchified with `patch_size`
|
||||
2) Have height and width divisible by `pooling_kernel_size * patch_size`
|
||||
"""
|
||||
total_px = height * width
|
||||
target_px = max_patches * (patch_size**2)
|
||||
factor = math.sqrt(target_px / total_px)
|
||||
ideal_height = factor * height
|
||||
ideal_width = factor * width
|
||||
side_mult = pooling_kernel_size * patch_size
|
||||
|
||||
# Round down to nearest multiple of side_mult
|
||||
target_height = int(math.floor(ideal_height / side_mult)) * side_mult
|
||||
target_width = int(math.floor(ideal_width / side_mult)) * side_mult
|
||||
|
||||
# Handle edge cases where one or both dimensions round to 0
|
||||
if target_height == 0 and target_width == 0:
|
||||
raise ValueError(
|
||||
"Attempting to resize to a 0 x 0 image. Resized height should be divisible by "
|
||||
f"`pooling_kernel_size * patch_size`={pooling_kernel_size * patch_size}."
|
||||
)
|
||||
|
||||
max_side_length = (max_patches // pooling_kernel_size**2) * side_mult
|
||||
if target_height == 0:
|
||||
target_height = side_mult
|
||||
target_width = min(
|
||||
int(math.floor(width / height)) * side_mult,
|
||||
max_side_length,
|
||||
)
|
||||
elif target_width == 0:
|
||||
target_width = side_mult
|
||||
target_height = min(
|
||||
int(math.floor(height / width)) * side_mult,
|
||||
max_side_length,
|
||||
)
|
||||
|
||||
if target_height * target_width > target_px:
|
||||
raise ValueError(
|
||||
f"Resizing [{height}x{width}] to [{target_height}x{target_width}] "
|
||||
f"but this exceeds {max_patches} patches with patch_size {patch_size}"
|
||||
)
|
||||
|
||||
return target_height, target_width
|
||||
|
||||
|
||||
# Copied from transformers.models.siglip2.image_processing_pil_siglip2.convert_image_to_patches
|
||||
def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""
|
||||
Convert 3D array image of shape (num_channels, image_height, image_width) into 2D array of patches of shape
|
||||
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
num_channels, image_height, image_width = image.shape
|
||||
num_patches_height = image_height // patch_size
|
||||
num_patches_width = image_width // patch_size
|
||||
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
|
||||
patched_image = patched_image.transpose(1, 3, 2, 4, 0)
|
||||
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
|
||||
return patched_image
|
||||
|
||||
|
||||
# Adapted from Siglip2 (mask -> position ids)
|
||||
def pad_along_first_dim(image: np.ndarray, positions: np.ndarray, target_length: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Pad the image along the first dimension.
|
||||
"""
|
||||
current_length = image.shape[0]
|
||||
padding_length = target_length - current_length
|
||||
if padding_length > 0:
|
||||
paddings = [(0, padding_length)] + [(0, 0)] * (image.ndim - 1)
|
||||
pos_paddings = [(0, padding_length), (0, 0)]
|
||||
image = np.pad(image, paddings, mode="constant", constant_values=0)
|
||||
positions = np.pad(positions, pos_paddings, mode="constant", constant_values=-1)
|
||||
return image, positions
|
||||
|
||||
|
||||
class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False):
|
||||
"""
|
||||
patch_size (`int`, *optional*):
|
||||
Size of each image patch in pixels.
|
||||
max_soft_tokens (`int`, *optional*):
|
||||
Maximum number of soft (vision) tokens per image.
|
||||
Must be one of {70, 140, 280, 560, 1120}.
|
||||
pooling_kernel_size (`int`, *optional*):
|
||||
Spatial pooling kernel size applied after patchification.
|
||||
"""
|
||||
|
||||
patch_size: int
|
||||
max_soft_tokens: int
|
||||
pooling_kernel_size: int
|
||||
|
||||
|
||||
@auto_docstring(custom_intro="Constructs a Gemma4 image processor.")
|
||||
class Gemma4ImageProcessorPil(PilBackend):
|
||||
valid_kwargs = Gemma4ImageProcessorKwargs
|
||||
model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"]
|
||||
|
||||
do_resize = True
|
||||
resample = PILImageResampling.BICUBIC
|
||||
do_rescale = True
|
||||
rescale_factor = 1 / 255
|
||||
do_normalize = False
|
||||
image_mean = [0.0, 0.0, 0.0]
|
||||
image_std = [1.0, 1.0, 1.0]
|
||||
do_convert_rgb = True
|
||||
patch_size = 16
|
||||
max_soft_tokens = 280
|
||||
pooling_kernel_size = 3
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
|
||||
|
||||
def _validate_preprocess_kwargs(self, **kwargs):
|
||||
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
|
||||
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
|
||||
# parameter. Temporarily disable do_resize so the base validation
|
||||
# doesn't require `size` to be set.
|
||||
kwargs["do_resize"] = False
|
||||
super()._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[Gemma4ImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def aspect_ratio_preserving_resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
resample: PILImageResampling,
|
    ) -> np.ndarray:
        height, width = image.shape[-2], image.shape[-1]
        target_height, target_width = get_aspect_ratio_preserving_size(
            height=height,
            width=width,
            patch_size=patch_size,
            max_patches=max_patches,
            pooling_kernel_size=pooling_kernel_size,
        )

        if target_height == height and target_width == width:
            return image

        return resize(
            image,
            size=(target_height, target_width),
            resample=resample,
        )

    def _preprocess(
        self,
        images: list[np.ndarray],
        do_resize: bool,
        resample: "PILImageResampling | int | None",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        return_tensors: str | TensorType | None,
        max_soft_tokens: int | None = None,
        patch_size: int | None = None,
        pooling_kernel_size: int | None = None,
        **kwargs,
    ) -> BatchFeature:
        if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
            raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")

        # Compute max_patches from max_soft_tokens and pooling_kernel_size
        max_patches = max_soft_tokens * pooling_kernel_size**2

        # Process each image individually: resize, rescale/normalize, patchify, pad.
        # Images have different aspect ratios and thus different resized dimensions,
        # so patchification and padding must happen per-image before stacking.
        pixel_values = []
        position_ids = []
        num_soft_tokens_per_image = []

        for image in images:
            # Step 1: Aspect-ratio-preserving resize
            if do_resize:
                image = self.aspect_ratio_preserving_resize(
                    image=image,
                    patch_size=patch_size,
                    max_patches=max_patches,
                    pooling_kernel_size=pooling_kernel_size,
                    resample=resample,
                )

            # Step 2: Rescale pixel values from [0, 255] to [0, 1]
            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor)

            # Step 3: Identity normalization because Gemma4 was trained with pixels in [0, 1]
            if do_normalize:
                image = self.normalize(image=image, mean=image_mean, std=image_std)

            # Step 4: Patchify the image
            # image is (C, H, W) numpy array; add batch dimension for reshape
            # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
            patches = convert_image_to_patches(image, patch_size)
            num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2)

            # Step 5: Compute position IDs
            patch_height = image.shape[-2] // patch_size
            patch_width = image.shape[-1] // patch_size
            grid_x, grid_y = np.meshgrid(np.arange(patch_width), np.arange(patch_height), indexing="xy")
            real_positions = np.stack([grid_x, grid_y], axis=-1).reshape(patches.shape[0], 2)

            patches, positions = pad_along_first_dim(patches, real_positions, max_patches)

            pixel_values.append(patches)
            position_ids.append(positions)

        # Stack into batch arrays and convert to tensors
        pixel_values = np.stack(pixel_values, axis=0)  # (batch, max_patches, patch_pixels)
        position_ids = np.stack(position_ids, axis=0)  # (batch, max_patches, 2)

        data = {
            "pixel_values": pixel_values,
            "image_position_ids": position_ids,
            "num_soft_tokens_per_image": num_soft_tokens_per_image,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
||||
|
||||
|
||||
__all__ = ["Gemma4ImageProcessorPil"]
|
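The patch-budget and position-ID arithmetic in `_preprocess` can be tried in isolation with NumPy. The following is a minimal sketch, not the library's implementation: `patch_position_ids` and `pad_to_max_patches` are hypothetical helper names, the padding values (zeros for patches, `-1` for positions) are assumptions because `pad_along_first_dim` is not shown here, and the numeric settings are illustrative rather than values from `_SUPPORTED_SOFT_TOKENS`.

```python
import numpy as np


def patch_position_ids(height: int, width: int, patch_size: int) -> np.ndarray:
    """Return (num_patches, 2) array of (x, y) patch coordinates in row-major order."""
    patch_height = height // patch_size
    patch_width = width // patch_size
    # indexing="xy" yields arrays of shape (patch_height, patch_width)
    grid_x, grid_y = np.meshgrid(np.arange(patch_width), np.arange(patch_height), indexing="xy")
    return np.stack([grid_x, grid_y], axis=-1).reshape(patch_height * patch_width, 2)


def pad_to_max_patches(patches: np.ndarray, positions: np.ndarray, max_patches: int, pad_position: int = -1):
    """Hypothetical stand-in for pad_along_first_dim; the pad values are assumptions."""
    pad = max_patches - patches.shape[0]
    padded_patches = np.pad(patches, ((0, pad), (0, 0)))  # zero-fill extra rows
    padded_positions = np.pad(positions, ((0, pad), (0, 0)), constant_values=pad_position)
    return padded_patches, padded_positions


# Illustrative settings only
max_soft_tokens, pooling_kernel_size, patch_size = 8, 2, 16
max_patches = max_soft_tokens * pooling_kernel_size**2

height, width = 64, 96  # already-resized image size, divisible by patch_size
positions = patch_position_ids(height, width, patch_size)
patches = np.zeros((positions.shape[0], 3 * patch_size**2), dtype=np.float32)
num_soft_tokens = patches.shape[0] // pooling_kernel_size**2

padded_patches, padded_positions = pad_to_max_patches(patches, positions, max_patches)
```

Padding every image to the same `max_patches` is what allows the per-image arrays to be stacked into a single batch even though each image keeps its own aspect ratio.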
||||
@@ -0,0 +1,723 @@
|
||||
# === HEADER (license + imports) ===
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma4.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_experts_implementation, use_kernelized_func
|
||||
from ...masking_utils import (
|
||||
create_bidirectional_mask,
|
||||
create_causal_mask,
|
||||
create_masks_for_generate,
|
||||
create_sliding_window_causal_mask,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_accelerate_available,
|
||||
torch_compilable_check,
|
||||
)
|
||||
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
||||
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
||||
from ..auto.modeling_auto import AutoModel
|
||||
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import add_hook_to_module
|
||||
|
||||
|
||||
|
||||
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Gemma4 outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
...
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Gemma4 causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class Gemma4CausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
...
|
||||
|
||||
@dataclass
|
||||
@auto_docstring
|
||||
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
|
||||
r"""
|
||||
attention_mask (`torch.BoolTensor`, *optional*):
|
||||
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
|
||||
"""
|
||||
|
||||
attention_mask: torch.BoolTensor | None = None
|
||||
|
||||
|
||||
class Gemma4ClippableLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
...
|
||||
|
class Gemma4RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale

        if self.with_scale:
            self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)

    def _norm(self, hidden_states: torch.Tensor):
        mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
        # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to address compiler differences between Torch and JAX
        return hidden_states * torch.pow(mean_squared, -0.5)
||||
|
||||
...
|
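The `_norm` step of `Gemma4RMSNorm` can be reproduced outside PyTorch. This is a minimal NumPy sketch of that one method only; how the learned `weight` is applied lives in the elided `forward` and is deliberately left out. The function name `rms_norm` is a placeholder.

```python
import numpy as np


def rms_norm(hidden_states: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Scale each vector along the last axis to (approximately) unit root-mean-square."""
    mean_squared = np.mean(hidden_states**2, axis=-1, keepdims=True) + eps
    # Same form as the source: multiply by mean_squared ** -0.5 instead of dividing by sqrt
    return hidden_states * np.power(mean_squared, -0.5)


x = np.array([[3.0, 4.0], [0.5, -0.5]])
y = rms_norm(x)
```

After normalization the mean of the squared entries in each row is close to 1, with a small deviation coming from `eps`.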
||||
|
||||
class Gemma4AudioRelPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal relative positional encoding for the audio encoder.
|
||||
|
||||
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
|
||||
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
|
||||
"""
|
||||
|
||||
inv_timescales: torch.Tensor
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.context_size = (
|
||||
config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
|
||||
...
|
||||
|
class Gemma4AudioAttention(nn.Module):
    """Chunked local attention with relative position bias"""

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_logits_soft_cap = config.attention_logit_cap
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads

        self.q_scale = (self.head_dim**-0.5) / math.log(2)
        self.k_scale = math.log(1 + math.e) / math.log(2)
||||
|
||||
...
|
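The two scale constants set in `Gemma4AudioAttention.__init__` are plain arithmetic and can be checked by hand. The sketch below only evaluates them; where they are applied in the forward pass is elided in this outline, so no usage is assumed. `audio_attention_scales` and `softplus` are hypothetical helper names, and `head_dim=64` is an illustrative value.

```python
import math


def audio_attention_scales(head_dim: int) -> tuple[float, float]:
    """Evaluate the q_scale and k_scale expressions from the constructor."""
    q_scale = (head_dim**-0.5) / math.log(2)
    k_scale = math.log(1 + math.e) / math.log(2)
    return q_scale, k_scale


def softplus(x: float) -> float:
    """softplus(x) = ln(1 + e^x)."""
    return math.log1p(math.exp(x))


q_scale, k_scale = audio_attention_scales(head_dim=64)
```

Written this way, `k_scale` is `softplus(1) / ln 2` and `q_scale` is the usual `1 / sqrt(head_dim)` factor divided by `ln 2`.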
||||
|
||||
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm_eps):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioSubSampleConvProjection(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=1,
|
||||
out_channels=config.subsampling_conv_channels[0],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=config.subsampling_conv_channels[0],
|
||||
out_channels=config.subsampling_conv_channels[1],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
|
||||
...
|
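Each `Gemma4AudioSubSampleConvProjectionLayer` uses a 3x3 convolution with stride 2 and padding 1, and two of them are stacked. A small sketch of the resulting size reduction, using the standard convolution output-size formula for dilation 1; the helper names and the input sizes are illustrative assumptions:

```python
def conv_output_size(size: int, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length along one axis for a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def subsampled_size(size: int, num_layers: int = 2) -> int:
    """Apply the per-layer reduction once for each stacked conv layer."""
    for _ in range(num_layers):
        size = conv_output_size(size)
    return size


frames_after = subsampled_size(100)  # 100 -> 50 -> 25
bins_after = subsampled_size(128)    # 128 -> 64 -> 32
```

With these settings each layer roughly halves both axes, so the pair of layers reduces time and frequency resolution by about a factor of four.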
||||
|
||||
class Gemma4AudioFeedForward(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
|
||||
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
|
||||
|
||||
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.gradient_clipping = config.gradient_clipping
|
||||
self.post_layer_scale = config.residual_weight
|
||||
...
|
||||
|
||||
class Gemma4AudioCausalConv1d(nn.Conv1d):
|
||||
# def __init__(
|
||||
# self,
|
||||
# in_channels: int,
|
||||
# out_channels: int,
|
||||
# kernel_size: int,
|
||||
# # cache_key: str,
|
||||
# stride: int = 1,
|
||||
# dilation: int = 1,
|
||||
# bias: bool = True,
|
||||
# ):
|
||||
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
|
||||
# self.cache_key = cache_key
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioLightConv1d(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
|
||||
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
|
||||
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
|
||||
in_channels=config.hidden_size,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.conv_kernel_size,
|
||||
groups=config.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
...
|
||||
|
||||
class Gemma4AudioLayer(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config)
|
||||
self.feed_forward2 = Gemma4AudioFeedForward(config)
|
||||
self.self_attn = Gemma4AudioAttention(config, layer_idx)
|
||||
self.lconv1d = Gemma4AudioLightConv1d(config)
|
||||
|
||||
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_out = Gemma4RMSNorm(config.hidden_size)
|
||||
|
||||
...
|
||||
|
class Gemma4VisionPatchEmbedder(nn.Module):
    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.position_embedding_size = config.position_embedding_size

        self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
        self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))

    def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
        """Prepare patch positions map for matmul with position embedding table."""
        # Expand and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
||||
...
|
||||
|
||||
class Gemma4VisionPooler(nn.Module):
|
||||
"""Scaling and optional spatial pooling for vision encodings"""
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.root_hidden_size = self.hidden_size**0.5
|
||||
|
||||
def _avg_pool_by_positions(
|
||||
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
2D spatial pooling according to patch positions.
|
||||
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
|
||||
...
|
||||
|
||||
class Gemma4VisionMLP(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
...
|
||||
|
||||
class Gemma4VisionRotaryEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig, device=None):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
|
||||
self.rope_type = self.config.rope_parameters["rope_type"]
|
||||
rope_init_fn: Callable = self.compute_default_rope_parameters
|
||||
if self.rope_type != "default":
|
||||
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
...
|
||||
|
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
||||
|
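The effect of `rotate_half` is easiest to see on a tiny vector. The NumPy sketch below mirrors it and then combines it with cosine and sine terms as `x * cos + rotate_half(x) * sin`. That combination is the common rotary-embedding convention and is an assumption here, since the body of `apply_rotary_pos_emb` is elided in this outline; `apply_rope` is a hypothetical name.

```python
import numpy as np


def rotate_half(x: np.ndarray) -> np.ndarray:
    """Swap the two halves of the last axis and negate the second half."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate((-x2, x1), axis=-1)


def apply_rope(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    """Assumed convention: rotate each (x1[i], x2[i]) pair by its angle."""
    return x * cos + rotate_half(x) * sin


x = np.array([1.0, 2.0, 3.0, 4.0])
angles = np.array([0.3, 1.1])  # one angle per pair, repeated across both halves
cos = np.concatenate((np.cos(angles), np.cos(angles)))
sin = np.concatenate((np.sin(angles), np.sin(angles)))
rotated = apply_rope(x, cos, sin)
```

Because each pair is rotated in its own 2D plane, the vector norm is unchanged; only the direction depends on the angles.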
||||
|
||||
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): The tensor to embed.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
...
|
||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
||||
|
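The expand-then-reshape trick in `repeat_kv` can be checked against a plain repeat along the head axis. A minimal NumPy sketch, with `np.broadcast_to` standing in for `Tensor.expand` and small illustrative shapes:

```python
import numpy as np


def repeat_kv(hidden_states: np.ndarray, n_rep: int) -> np.ndarray:
    """Repeat each key/value head n_rep times along axis 1 (interleaved)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = np.broadcast_to(
        hidden_states[:, :, None, :, :],
        (batch, num_key_value_heads, n_rep, slen, head_dim),
    )
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = np.arange(2 * 2 * 3 * 4, dtype=np.float32).reshape(2, 2, 3, 4)
out = repeat_kv(kv, n_rep=3)
```

Heads `0..n_rep-1` of the output are copies of KV head 0, the next `n_rep` heads are copies of KV head 1, and so on, which matches the grouping used in grouped-query attention.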
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
...
|
||||
|
||||
def apply_multidimensional_rope(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
unsqueeze_dim: int = 2,
|
||||
) -> torch.Tensor:
|
||||
"""Applies multidimensional RoPE to inputs.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): The tensor to embed.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
...
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class Gemma4VisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = 1.0
|
||||
self.attention_dropout = self.config.attention_dropout
|
||||
self.is_causal = False
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4VisionMLP(config)
|
||||
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
...
|
||||
|
class Gemma4TextMLP(nn.Module):
    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
        is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]
||||
|
||||
...
|
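The chained comparison `layer_idx >= first_kv_shared_layer_idx > 0` in `Gemma4TextMLP.__init__` decides which layers get the doubled intermediate size. A pure-Python sketch of that rule, using a hypothetical helper name and made-up configuration values:

```python
def mlp_intermediate_size(
    layer_idx: int,
    num_hidden_layers: int,
    num_kv_shared_layers: int,
    intermediate_size: int,
    use_double_wide_mlp: bool,
) -> int:
    """Mirror the width selection from the constructor for a single layer."""
    first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers
    # True only if the layer is in the shared tail AND that tail does not start at layer 0
    is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
    return intermediate_size * (2 if use_double_wide_mlp and is_kv_shared_layer else 1)


# Illustrative 6-layer model whose last 2 layers share KV
sizes = [mlp_intermediate_size(i, 6, 2, 1024, True) for i in range(6)]
# No shared layers, and the case where every layer would be shared
sizes_none = [mlp_intermediate_size(i, 6, 0, 1024, True) for i in range(6)]
sizes_all = [mlp_intermediate_size(i, 6, 6, 1024, True) for i in range(6)]
```

The `> 0` part of the comparison means a configuration in which the shared tail would start at layer 0 is treated the same as having no shared layers for the purpose of MLP width.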
||||
|
||||
class Gemma4TextRotaryEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.layer_types = set(config.layer_types)
|
||||
self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
|
||||
self.rope_type: dict[str, str] = {}
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
...
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class Gemma4TextAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.is_sliding = self.layer_type == "sliding_attention"
|
||||
self.sliding_window = config.sliding_window if self.is_sliding else None
|
||||
|
||||
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
|
||||
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
|
||||
...
|
||||
|
||||
@use_experts_implementation
|
||||
class Gemma4TextExperts(nn.Module):
|
||||
"""Collection of expert weights stored as 3D tensors."""
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.moe_intermediate_size
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
|
||||
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(
|
||||
...
|
||||
|
||||
class Gemma4TextRouter(nn.Module):
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.scalar_root_size = self.hidden_size**-0.5
|
||||
self.eps = config.rms_norm_eps
|
||||
|
||||
self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
|
||||
self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
||||
self.scale = nn.Parameter(torch.ones(self.hidden_size))
|
||||
self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
...
|
||||
|
||||
class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4TextMLP(config, layer_idx)
|
||||
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.register_buffer("layer_scalar", torch.ones(1))
|
||||
|
||||
...
|
||||
|
||||
class Gemma4TextScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.scalar_embed_scale = embed_scale
|
||||
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
||||
|
||||
|
||||
...
|
||||
|
||||
@auto_docstring
|
||||
class Gemma4PreTrainedModel(PreTrainedModel):
|
||||
config: Gemma4Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = None # override
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
|
||||
class Gemma4TextModel(Gemma4PreTrainedModel):
|
||||
config: Gemma4TextConfig
|
||||
input_modalities = ("text",)
|
||||
_can_record_outputs = {
|
||||
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
|
||||
"hidden_states": Gemma4TextDecoderLayer,
|
||||
"attentions": Gemma4TextAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
|
||||
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_gather_output"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
config: Gemma4TextConfig
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
self.model = Gemma4TextModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
# Grab the ones from the child
|
||||
...
|
||||
|
def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
    """
    This creates uni/bidirectional attention mask with sliding window.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        left_window_size, right_window_size = sliding_window

        dist = q_idx - kv_idx
        left_mask = (dist >= 0) & (dist < left_window_size)
        right_mask = (dist < 0) & (-dist < right_window_size)
        return left_mask | right_mask

    return inner_mask
||||
...
|
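The predicate returned by `sliding_window_mask_function` can be expanded into a full boolean matrix to see which key positions each query may attend to. A pure-Python sketch with a hypothetical helper name and small illustrative window sizes:

```python
def sliding_window_mask(seq_len: int, sliding_window: tuple[int, int]) -> list[list[bool]]:
    """Evaluate the inner_mask rule for every (q_idx, kv_idx) pair."""
    left_window_size, right_window_size = sliding_window
    mask = []
    for q_idx in range(seq_len):
        row = []
        for kv_idx in range(seq_len):
            dist = q_idx - kv_idx
            left_mask = (dist >= 0) and (dist < left_window_size)
            right_mask = (dist < 0) and (-dist < right_window_size)
            row.append(left_mask or right_mask)
        mask.append(row)
    return mask


bidirectional = sliding_window_mask(4, (2, 2))
causal_like = sliding_window_mask(4, (2, 1))
```

The left window counts the query position itself, so a left size of 2 covers positions `q` and `q - 1`. The right window uses a strict comparison on the distance, so a right size of 2 admits only `q + 1` and a right size of 1 admits no future positions, which yields a causal window.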
||||
|
||||
class Gemma4AudioModel(Gemma4PreTrainedModel):
|
||||
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
|
||||
|
||||
config: Gemma4AudioConfig
|
||||
main_input_name = "input_features"
|
||||
base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Gemma4AudioLayer,
|
||||
"attentions": Gemma4AudioAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
...
|
||||
|
||||
class Gemma4VisionModel(Gemma4PreTrainedModel):
|
||||
"""The Gemma 4 Vision Encoder."""
|
||||
|
||||
config = Gemma4VisionConfig
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Gemma4VisionEncoderLayer,
|
||||
"attentions": Gemma4VisionAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__(config)
|
||||
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
|
||||
self.encoder = Gemma4VisionEncoder(config)
|
||||
self.pooler = Gemma4VisionPooler(config)
|
||||
...
|
||||
|
||||
class Gemma4MultimodalEmbedder(nn.Module):
|
||||
"""Embeds token ids or soft tokens for multimodal content into language model space."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
|
||||
text_config: Gemma4TextConfig,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
|
||||
self.eps = multimodal_config.rms_norm_eps
|
||||
self.text_hidden_size = text_config.hidden_size
|
||||
self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
|
||||
...
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
token_type_ids: torch.Tensor | None,
|
||||
image_group_ids: torch.Tensor | None,
|
||||
) -> Callable | None:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
"""
|
||||
# Do not return an additional mask in this case
|
||||
if token_type_ids is None:
|
||||
return None
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
seq_length = image_group_ids.shape[-1]
|
||||
...
|
||||
|
||||
def create_causal_mask_mapping(
|
||||
config: PreTrainedConfig,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_values: Cache | None,
|
||||
position_ids: torch.Tensor | None,
|
||||
mm_token_type_ids: torch.Tensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
is_training: bool = False,
|
||||
is_first_iteration: bool | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
|
||||
language modeling head.
|
||||
"""
|
||||
)
|
||||
class Gemma4Model(Gemma4PreTrainedModel):
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
||||
head.
|
||||
"""
|
||||
)
|
||||
class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
accepts_loss_kwargs = False
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
self.model = Gemma4Model(config)
|
||||
...
|
||||
|
||||
@@ -0,0 +1,563 @@
|
||||
# === HEADER (license + imports) ===
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...integrations import use_kernelized_func
|
||||
from ...masking_utils import (
|
||||
create_bidirectional_mask,
|
||||
create_causal_mask,
|
||||
create_masks_for_generate,
|
||||
create_sliding_window_causal_mask,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
torch_compilable_check,
|
||||
)
|
||||
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
||||
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
||||
from ..auto.modeling_auto import AutoModel
|
||||
from ..gemma3.modeling_gemma3 import (
|
||||
Gemma3Attention,
|
||||
Gemma3DecoderLayer,
|
||||
Gemma3ForCausalLM,
|
||||
Gemma3MLP,
|
||||
Gemma3RotaryEmbedding,
|
||||
Gemma3TextModel,
|
||||
Gemma3TextScaledWordEmbedding,
|
||||
)
|
||||
from ..gemma3n.modeling_gemma3n import (
|
||||
Gemma3nCausalLMOutputWithPast,
|
||||
Gemma3nForConditionalGeneration,
|
||||
Gemma3nModel,
|
||||
Gemma3nModelOutputWithPast,
|
||||
Gemma3nMultimodalEmbedder,
|
||||
Gemma3nPreTrainedModel,
|
||||
Gemma3nRMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from ..llama.modeling_llama import LlamaRotaryEmbedding
|
||||
from ..mixtral.modeling_mixtral import MixtralExperts
|
||||
from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function
|
||||
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
|
||||
class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring
|
||||
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
|
||||
r"""
|
||||
attention_mask (`torch.BoolTensor`, *optional*):
|
||||
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
|
||||
...
|
||||
|
||||
class Gemma4ClippableLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma4VisionConfig | Gemma4AudioConfig,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.use_clipped_linears = config.use_clipped_linears
|
||||
self.linear = nn.Linear(in_features, out_features, bias=False)
|
||||
|
||||
if self.use_clipped_linears:
|
||||
self.register_buffer("input_min", torch.tensor(-float("inf")))
|
||||
self.register_buffer("input_max", torch.tensor(float("inf")))
|
||||
...
|
||||
|
||||
class Gemma4RMSNorm(Gemma3nRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma4AudioRelPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal relative positional encoding for the audio encoder.
|
||||
|
||||
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
|
||||
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
|
||||
"""
|
||||
|
||||
inv_timescales: torch.Tensor
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
...
|
||||
|
||||
class Gemma4AudioAttention(nn.Module):
|
||||
"""Chunked local attention with relative position bias"""
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_logits_soft_cap = config.attention_logit_cap
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.q_scale = (self.head_dim**-0.5) / math.log(2)
|
||||
self.k_scale = math.log(1 + math.e) / math.log(2)
|
||||
|
||||
...
|
||||
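The two scale constants above are plain arithmetic and can be checked in isolation. The sketch below only reproduces the two expressions from `__init__`; the function name `audio_attention_scales` and the head dimension of 64 are hypothetical. Note that `log(1 + e)` is `softplus(1)`, so `k_scale` equals `log2(1 + e)`.

```python
import math


def audio_attention_scales(head_dim: int) -> tuple[float, float]:
    """Recompute q_scale and k_scale exactly as written in the constructor."""
    q_scale = (head_dim ** -0.5) / math.log(2)
    k_scale = math.log(1 + math.e) / math.log(2)
    return q_scale, k_scale


# Hypothetical head dimension, chosen only for illustration.
q_scale, k_scale = audio_attention_scales(64)
print(f"q_scale={q_scale:.6f} k_scale={k_scale:.6f}")
```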
|
||||
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm_eps):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioSubSampleConvProjection(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=1,
|
||||
out_channels=config.subsampling_conv_channels[0],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=config.subsampling_conv_channels[0],
|
||||
out_channels=config.subsampling_conv_channels[1],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
|
||||
...
|
||||
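Each of the two layers above uses a 3x3 kernel with stride 2 and padding 1, so each one roughly halves the time axis. The sketch below applies the standard convolution output-length formula to show the resulting 4x reduction. The helper names and the 100-frame input are assumptions for illustration.

```python
def conv_out_len(n: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution without dilation along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


def subsampled_frames(num_frames: int, num_layers: int = 2) -> int:
    """Apply the length formula once per subsampling layer."""
    for _ in range(num_layers):
        num_frames = conv_out_len(num_frames)
    return num_frames


# Hypothetical input: 1 second of audio at one frame per 10 ms.
print(subsampled_frames(100))
```

Under that assumption, 100 input frames become 25 output positions, which is consistent with the 40 ms per soft token used by the processor.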
|
||||
class Gemma4AudioFeedForward(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
|
||||
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
|
||||
|
||||
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.gradient_clipping = config.gradient_clipping
|
||||
self.post_layer_scale = config.residual_weight
|
||||
...
|
||||
|
||||
class Gemma4AudioCausalConv1d(nn.Conv1d):
|
||||
# def __init__(
|
||||
# self,
|
||||
# in_channels: int,
|
||||
# out_channels: int,
|
||||
# kernel_size: int,
|
||||
# # cache_key: str,
|
||||
# stride: int = 1,
|
||||
# dilation: int = 1,
|
||||
# bias: bool = True,
|
||||
# ):
|
||||
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
|
||||
# self.cache_key = cache_key
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioLightConv1d(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
|
||||
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
|
||||
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
|
||||
in_channels=config.hidden_size,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.conv_kernel_size,
|
||||
groups=config.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
...
|
||||
|
||||
class Gemma4AudioLayer(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config)
|
||||
self.feed_forward2 = Gemma4AudioFeedForward(config)
|
||||
self.self_attn = Gemma4AudioAttention(config, layer_idx)
|
||||
self.lconv1d = Gemma4AudioLightConv1d(config)
|
||||
|
||||
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_out = Gemma4RMSNorm(config.hidden_size)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4VisionPatchEmbedder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.patch_size = config.patch_size
|
||||
self.position_embedding_size = config.position_embedding_size
|
||||
|
||||
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
|
||||
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
|
||||
|
||||
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
|
||||
"""Prepare patch positions map for matmul with position embedding table."""
|
||||
# Expand and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
|
||||
...
|
||||
|
||||
class Gemma4VisionPooler(nn.Module):
|
||||
"""Scaling and optional spatial pooling for vision encodings"""
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.root_hidden_size = self.hidden_size**0.5
|
||||
|
||||
def _avg_pool_by_positions(
|
||||
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
2D spatial pooling according to patch positions.
|
||||
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
|
||||
...
|
||||
|
||||
class Gemma4VisionMLP(Gemma3MLP):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__(config)
|
||||
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
|
||||
|
||||
|
||||
def apply_multidimensional_rope(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
unsqueeze_dim: int = 2,
|
||||
...
|
||||
|
||||
class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
@staticmethod
|
||||
def compute_default_rope_parameters(
|
||||
config: Gemma4VisionConfig | None = None,
|
||||
device: torch.device | None = None,
|
||||
seq_len: int | None = None,
|
||||
) -> tuple["torch.Tensor", float]:
|
||||
"""
|
||||
Computes the inverse frequencies according to the original RoPE implementation
|
||||
Args:
|
||||
config ([`~transformers.PreTrainedConfig`]):
|
||||
The model configuration.
|
||||
device (`torch.device`):
|
||||
The device to use for initialization of the inverse frequencies.
|
||||
...
|
||||
|
||||
class Gemma4VisionAttention(Gemma3Attention):
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
del self.attn_logit_softcapping
|
||||
del self.sliding_window
|
||||
del self.is_sliding
|
||||
self.scaling = 1.0
|
||||
self.is_causal = False
|
||||
self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
|
||||
self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
|
||||
self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
|
||||
self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
|
||||
self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4VisionMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
...
|
||||
|
||||
class Gemma4TextMLP(Gemma3MLP):
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
|
||||
is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
|
||||
super().__init__()
|
||||
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
|
||||
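The chained comparison `layer_idx >= first_kv_shared_layer_idx > 0` is easy to misread, so here is a small standalone sketch of the same rule. The function name and the configuration values (6 layers, 2 KV-shared layers, width 1024) are hypothetical.

```python
def mlp_intermediate_size(
    layer_idx: int,
    num_hidden_layers: int,
    num_kv_shared_layers: int,
    intermediate_size: int,
    use_double_wide_mlp: bool,
) -> int:
    """Mirror the width selection from the constructor above."""
    first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers
    # True only if the layer is in the shared tail AND that tail does not start at layer 0.
    is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
    double = use_double_wide_mlp and is_kv_shared_layer
    return intermediate_size * (2 if double else 1)


# Hypothetical config: the last 2 of 6 layers share KV states.
sizes = [mlp_intermediate_size(i, 6, 2, 1024, True) for i in range(6)]
print(sizes)
```

One consequence of the `> 0` guard: if every layer were marked as KV-shared, the first shared index would be 0 and no layer would be widened.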
|
||||
|
||||
class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding):
|
||||
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
|
||||
nn.Module.__init__(self)
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
...
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class Gemma4TextAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.is_sliding = self.layer_type == "sliding_attention"
|
||||
self.sliding_window = config.sliding_window if self.is_sliding else None
|
||||
|
||||
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
|
||||
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
|
||||
...
|
||||
|
||||
class Gemma4TextExperts(MixtralExperts):
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.intermediate_dim = config.moe_intermediate_size
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
|
||||
class Gemma4TextRouter(nn.Module):
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.scalar_root_size = self.hidden_size**-0.5
|
||||
...
|
||||
|
||||
class Gemma4TextDecoderLayer(Gemma3DecoderLayer):
|
||||
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4TextMLP(config, layer_idx)
|
||||
self.register_buffer("layer_scalar", torch.ones(1))
|
||||
|
||||
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
|
||||
if self.hidden_size_per_layer_input:
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
|
||||
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
|
||||
self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
# ---- Model Classes ----
|
||||
|
||||
|
||||
class Gemma4PreTrainedModel(Gemma3nPreTrainedModel):
|
||||
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
|
||||
input_modalities = ("image", "text", "video", "audio")
|
||||
_can_record_outputs = None # override
|
||||
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
|
||||
|
||||
@torch.no_grad()
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
|
||||
class Gemma4TextModel(Gemma3TextModel):
|
||||
config: Gemma4TextConfig
|
||||
_can_record_outputs = {
|
||||
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
|
||||
"hidden_states": Gemma4TextDecoderLayer,
|
||||
"attentions": Gemma4TextAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
|
||||
class Gemma4ForCausalLM(Gemma3ForCausalLM):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
# Grab the ones from the child
|
||||
self._keys_to_ignore_on_load_unexpected = [
|
||||
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
|
||||
]
|
||||
|
||||
|
||||
class Gemma4AudioModel(Gemma4PreTrainedModel):
|
||||
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
|
||||
...
|
||||
|
||||
class Gemma4VisionModel(Gemma4PreTrainedModel):
|
||||
"""The Gemma 4 Vision Encoder."""
|
||||
|
||||
config = Gemma4VisionConfig
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Gemma4VisionEncoderLayer,
|
||||
"attentions": Gemma4VisionAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__(config)
|
||||
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
|
||||
self.encoder = Gemma4VisionEncoder(config)
|
||||
self.pooler = Gemma4VisionPooler(config)
|
||||
...
|
||||
|
||||
class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder):
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
|
||||
text_config: Gemma4TextConfig,
|
||||
):
|
||||
# Audio tower may use a different output dimension (output_proj_dims) than the
|
||||
# internal hidden_size. Use the tower-specific dimension if specified.
|
||||
super().__init__(multimodal_config, text_config)
|
||||
del self.embedding
|
||||
del self.hard_embedding_norm
|
||||
del self.soft_embedding_norm
|
||||
del self.vocab_offset
|
||||
del self.vocab_size
|
||||
...
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
token_type_ids: torch.Tensor | None,
|
||||
image_group_ids: torch.Tensor | None,
|
||||
) -> Callable | None:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
"""
|
||||
# Do not return an additional mask in this case
|
||||
if token_type_ids is None:
|
||||
return None
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
seq_length = image_group_ids.shape[-1]
|
||||
...
|
||||
|
||||
def create_causal_mask_mapping(
|
||||
config: PreTrainedConfig,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_values: Cache | None,
|
||||
position_ids: torch.Tensor | None,
|
||||
mm_token_type_ids: torch.Tensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
is_training: bool = False,
|
||||
is_first_iteration: bool | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
|
||||
language modeling head.
|
||||
"""
|
||||
)
|
||||
class Gemma4Model(Gemma3nModel):
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
del self.vision_tower
|
||||
del self.embed_vision
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
|
||||
self.embed_vision = (
|
||||
Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
||||
head.
|
||||
"""
|
||||
)
|
||||
class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
# Grab the ones from the child
|
||||
self._keys_to_ignore_on_load_unexpected = [
|
||||
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
|
||||
...
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...audio_utils import AudioInput
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, make_nested_list_of_images
|
||||
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import auto_docstring, is_vision_available, logging
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from .image_processing_pil_gemma4 import Gemma4ImageProcessorKwargs, get_aspect_ratio_preserving_size
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma4ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: Gemma4ImageProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": True,
|
||||
"return_mm_token_type_ids": True,
|
||||
},
|
||||
"images_kwargs": {
|
||||
"do_convert_rgb": True,
|
||||
},
|
||||
"audio_kwargs": {},
|
||||
"videos_kwargs": {"return_metadata": True},
|
||||
}
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@requires(backends=("vision",))
|
||||
class Gemma4Processor(ProcessorMixin):
|
||||
def __init__(
|
||||
self,
|
||||
feature_extractor,
|
||||
image_processor,
|
||||
tokenizer,
|
||||
video_processor,
|
||||
chat_template=None,
|
||||
image_seq_length: int = 280,
|
||||
audio_seq_length: int = 750,
|
||||
audio_ms_per_token: int = 40,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
image_seq_length (`int`, *optional*, defaults to 280):
|
||||
The number of soft tokens per image used for placeholder expansion.
|
||||
audio_seq_length (`int`, *optional*, defaults to 750):
|
||||
The maximum number of audio soft tokens per audio segment. Serves as an
|
||||
upper-bound cap when dynamic audio token counts are computed.
|
||||
audio_ms_per_token (`int`, *optional*, defaults to 40):
|
||||
Milliseconds of audio per output soft token. Used to dynamically compute
|
||||
the number of audio placeholder tokens as ``ceil(duration_ms / audio_ms_per_token)``.
|
||||
The default of 40 comes from the SSCP convolution's 4× time reduction on 10ms frames.
|
||||
"""
|
||||
self.image_seq_length = image_seq_length
|
||||
self.image_token_id = tokenizer.image_token_id
|
||||
self.boi_token = tokenizer.boi_token
|
||||
self.eoi_token = tokenizer.eoi_token
|
||||
self.image_token = tokenizer.image_token
|
||||
|
||||
# FIXME: add the token to config and ask Ryan to re-upload
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<|video|>"]})
|
||||
self.video_token = "<|video|>"
|
||||
self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
|
||||
# Audio token handling, mirroring the vision pattern.
|
||||
# audio_seq_length serves as the maximum cap on the number of audio soft tokens
|
||||
# any single audio segment can produce. With dynamic audio tokens, the actual
|
||||
# number of placeholders inserted per audio is computed from the audio duration.
|
||||
self.audio_seq_length = audio_seq_length
|
||||
# Milliseconds of audio per output soft token. The default of 40 comes from the
|
||||
# SSCP convolution's 4× time reduction applied to 10ms mel spectrogram frames.
|
||||
self.audio_ms_per_token = audio_ms_per_token
|
||||
self.audio_token_id = getattr(tokenizer, "audio_token_id", None)
|
||||
self.audio_token = getattr(tokenizer, "audio_token", None)
|
||||
self.boa_token = getattr(tokenizer, "boa_token", None)
|
||||
self.eoa_token = getattr(tokenizer, "eoa_token", None)
|
||||
|
||||
super().__init__(
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
chat_template=chat_template,
|
||||
**kwargs,
|
||||
)
|
||||
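The docstring above describes the audio placeholder count as `ceil(duration_ms / audio_ms_per_token)` capped by `audio_seq_length`. The sketch below implements only that documented formula. The actual `_compute_audio_num_tokens` method is described as replicating the encoder's sequence-length arithmetic and may differ in edge cases. The function name and the 16 kHz sampling rate are assumptions.

```python
import math


def audio_num_tokens(
    num_samples: int,
    sampling_rate: int,
    ms_per_token: int = 40,
    max_tokens: int = 750,
) -> int:
    """Documented approximation: ceil(duration_ms / ms_per_token), capped at max_tokens."""
    duration_ms = num_samples * 1000 / sampling_rate
    return min(math.ceil(duration_ms / ms_per_token), max_tokens)


# Hypothetical 16 kHz waveforms of 1 s, 30 s, and 60 s.
for seconds in (1, 30, 60):
    print(seconds, audio_num_tokens(seconds * 16000, 16000))
```

With the defaults, the cap of 750 tokens corresponds to 30 seconds of audio; longer clips would need to be split or truncated before they reach the model.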
|
||||
@auto_docstring
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput | None = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
|
||||
audio: AudioInput | None = None,
|
||||
videos: VideoInput | None = None,
|
||||
**kwargs: Unpack[Gemma4ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
if text is None and images is None and audio is None and videos is None:
|
||||
raise ValueError("Provide at least one of `text`, `images`, `audio`, or `videos`.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Gemma4ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif text is not None and not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
|
||||
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
image_inputs = {}
|
||||
if images is not None:
|
||||
images = self.image_processor.fetch_images(images)
|
||||
batched_images = make_nested_list_of_images(images)
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
num_soft_tokens = image_inputs.pop("num_soft_tokens_per_image")
|
||||
|
||||
# Create empty text to be replaced with placeholders
|
||||
if not text:
|
||||
text = [" ".join([self.image_token] * len(images)) for images in batched_images]
|
||||
|
||||
if len(batched_images) != len(text):
|
||||
raise ValueError(
|
||||
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
||||
)
|
||||
|
||||
replacements = [f"{self.boi_token}{self.image_token * n}{self.eoi_token}" for n in num_soft_tokens]
|
||||
replacements_iter = iter(replacements)
|
||||
|
||||
# Expand image_token placeholders to per-image soft token sequences.
|
||||
# re.sub does not re-scan text it has already substituted, so the image tokens inside each replacement are not expanded again.
|
||||
pattern = re.escape(self.image_token)
|
||||
text = [re.sub(pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
|
||||
|
||||
# Process video inputs in the same way
|
||||
video_inputs = {}
|
||||
if videos is not None:
|
||||
video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
num_video_tokens = video_inputs.pop("num_soft_tokens_per_video")
|
||||
|
||||
# If user has not requested video metadata, pop it so it isn't returned
|
||||
if not kwargs.get("return_metadata"):
|
||||
video_metadata = video_inputs.pop("video_metadata")
|
||||
else:
|
||||
video_metadata = video_inputs["video_metadata"]
|
||||
|
||||
video_replacements = []
|
||||
for metadata, n_tokens in zip(video_metadata, num_video_tokens):
|
||||
if metadata.fps is None:
|
||||
logger.warning_once(
|
||||
"Gemma 4 requires frame timestamps to construct prompts, but the `fps` of the input video "
|
||||
"could not be inferred. Probably `video_metadata` was missing from inputs and you passed "
|
||||
"pre-sampled frames. Defaulting to `fps=24`. Please provide `video_metadata` for more "
|
||||
"accurate results."
|
||||
)
|
||||
metadata.fps = 24 if metadata.fps is None else metadata.fps
|
||||
# mm:ss format for timestamps
|
||||
timestamp_str = [
|
||||
f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" for seconds in metadata.timestamps
|
||||
]
|
||||
video_replacements.append(
|
||||
" ".join(
|
||||
[f"{t} {self.boi_token}{self.video_token * n_tokens}{self.eoi_token}" for t in timestamp_str]
|
||||
)
|
||||
)
|
||||
|
||||
video_replacements = iter(video_replacements)
|
||||
pattern = re.escape(self.video_token)
|
||||
text = [re.sub(pattern, lambda _: next(video_replacements), prompt) for prompt in text]
|
||||
|
||||
# Process audio inputs
|
||||
audio_inputs = {}
|
||||
if audio is not None:
|
||||
if self.audio_token is None or self.boa_token is None or self.eoa_token is None:
|
||||
raise ValueError(
|
||||
"Audio inputs were provided, but the tokenizer does not define `audio_token`, `boa_token`, and `eoa_token`."
|
||||
)
|
||||
|
||||
# Normalize audio input to list of waveforms
|
||||
if isinstance(audio, np.ndarray) and audio.ndim == 1:
|
||||
audio = [audio]
|
||||
|
||||
# TODO: Add tests for audio-only processor inputs.
|
||||
if not text:
|
||||
text = [self.audio_token] * len(audio)
|
||||
|
||||
# Dynamic audio token expansion without padding:
|
||||
# * Extract audio features with feature extractor;
|
||||
# * Compute precise per-audio token counts from the waveform duration;
|
||||
# * Generate full audio token sequence for each computed audio length;
|
||||
# * Expand text prompts with full audio token sequences.
|
||||
audio_kwargs = output_kwargs.get("audio_kwargs", {})
|
||||
audio_inputs = self.feature_extractor(audio, **audio_kwargs)
|
||||
sampling_rate = self.feature_extractor.sampling_rate
|
||||
num_audio_tokens = [self._compute_audio_num_tokens(a, sampling_rate) for a in audio]
|
||||
replacements = [f"{self.boa_token}{self.audio_token * n}{self.eoa_token}" for n in num_audio_tokens]
|
||||
replacements_iter = iter(replacements)
|
||||
audio_pattern = re.escape(self.audio_token)
|
||||
text = [re.sub(audio_pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||
|
||||
# Check special tokens for all active modalities
|
||||
active_modalities = []
|
||||
if images is not None:
|
||||
active_modalities.append("image")
|
||||
if videos is not None:
|
||||
active_modalities.append("video")
|
||||
if audio is not None:
|
||||
active_modalities.append("audio")
|
||||
if active_modalities:
|
||||
self._check_special_mm_tokens(text, text_inputs, modalities=active_modalities)
|
||||
|
||||
if return_mm_token_type_ids:
|
||||
text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **image_inputs, **audio_inputs, **video_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
    def _compute_audio_num_tokens(self, audio_waveform, sampling_rate: int) -> int:
        """Compute the number of audio soft tokens for a single waveform.

        Replicates the exact sequence-length arithmetic of the audio encoder
        so that the processor inserts the correct number of placeholder tokens.
        The computation mirrors:

        1. Mel framing via ``_unfold`` in ``Gemma4AudioFeatureExtractor``
        2. Two ``Conv2d`` subsampling layers in ``Gemma4AudioSubSampleConvProjection``
           (each: kernel=3, stride=2, semicausal padding top=1, bottom=1)

        The result is capped at ``self.audio_seq_length`` (the configured maximum).

        Args:
            audio_waveform: A 1-D numpy array or list containing the raw audio samples.
            sampling_rate: The sampling rate of the audio waveform in Hz.

        Returns:
            The number of audio soft tokens to insert as placeholders.
        """
        num_samples = len(audio_waveform)

        # Step 1: Mel frames (matches feature_extraction_gemma4.py _unfold)
        frame_length = int(round(sampling_rate * 20.0 / 1000.0))  # 320 @ 16kHz
        hop_length = int(round(sampling_rate * 10.0 / 1000.0))  # 160 @ 16kHz
        frame_size_for_unfold = frame_length + 1  # 321

        # The feature extractor prepends (frame_length // 2) zero samples as
        # semicausal time-padding before the unfold. We must include this to
        # match the actual number of mel frames it produces.
        pad_left = frame_length // 2  # 160 @ 16kHz
        padded_samples = num_samples + pad_left
        num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1

        if num_mel_frames <= 0:
            return 0

        # Step 2: Two SSCP conv layers (kernel=3, stride=2, semicausal pad top=1, bottom=1)
        # Each layer: T_out = (T_in + pad_top + pad_bottom - kernel) // stride + 1
        t = num_mel_frames
        for _ in range(2):
            t_padded = t + 2  # pad_top=1, pad_bottom=1
            t = (t_padded - 3) // 2 + 1

        # Cap at the configured maximum
        return min(t, self.audio_seq_length)
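
The frame and stride arithmetic in `_compute_audio_num_tokens` is easier to check with concrete numbers. The sketch below is a standalone copy of that arithmetic, not part of the upstream file. The name `audio_num_tokens` and the optional `max_tokens` argument are hypothetical; the method above always caps at `self.audio_seq_length`, whose value is not shown in this diff.

```python
from typing import Optional


def audio_num_tokens(num_samples: int, sampling_rate: int = 16_000, max_tokens: Optional[int] = None) -> int:
    """Standalone copy of the frame and stride arithmetic (hypothetical helper)."""
    # Step 1: mel frames from 20 ms windows with a 10 ms hop.
    frame_length = int(round(sampling_rate * 20.0 / 1000.0))
    hop_length = int(round(sampling_rate * 10.0 / 1000.0))
    padded_samples = num_samples + frame_length // 2  # semicausal left padding
    num_mel_frames = (padded_samples - (frame_length + 1)) // hop_length + 1
    if num_mel_frames <= 0:
        return 0

    # Step 2: two conv layers, each kernel=3, stride=2, one row of padding per side.
    t = num_mel_frames
    for _ in range(2):
        t = (t + 2 - 3) // 2 + 1

    # Assumption: the cap is optional here; the method above always applies one.
    return t if max_tokens is None else min(t, max_tokens)


one_second = audio_num_tokens(16_000)       # 99 mel frames -> 50 -> 25 tokens
thirty_seconds = audio_num_tokens(480_000)  # 2999 mel frames -> 1500 -> 750 tokens
too_short = audio_num_tokens(100)           # shorter than one frame -> 0 tokens
```

At 16 kHz this works out to 25 placeholder tokens per second of audio, i.e. one token per 40 ms once both stride-2 layers have been applied to the 10 ms mel hop.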

    def _get_num_multimodal_tokens(self, image_sizes=None, audio_lengths=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.
            audio_lengths (`list[int]`, *optional*):
                The lengths of audio inputs in number of samples. Used to dynamically
                compute per-audio token counts.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """

        images_kwargs = Gemma4ProcessorKwargs._defaults.get("images_kwargs", {})
        images_kwargs.update(kwargs)
        patch_size = images_kwargs.get("patch_size", None) or self.image_processor.patch_size
        pooling_kernel_size = (
            images_kwargs.get("pooling_kernel_size", None) or self.image_processor.pooling_kernel_size
        )
        max_soft_tokens = images_kwargs.get("max_soft_tokens", None) or self.image_processor.max_soft_tokens

        max_patches = max_soft_tokens * pooling_kernel_size**2

        vision_data = {}
        if image_sizes is not None:
            num_image_tokens = []
            for image_size in image_sizes:
                target_h, target_w = get_aspect_ratio_preserving_size(
                    height=image_size[0],
                    width=image_size[1],
                    patch_size=patch_size,
                    max_patches=max_patches,
                    pooling_kernel_size=pooling_kernel_size,
                )
                patch_height = target_h // patch_size
                patch_width = target_w // patch_size
                num_image_tokens.append(patch_height * patch_width // pooling_kernel_size**2)

            num_image_patches = [1] * len(image_sizes)
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        if audio_lengths is not None:
            # Dynamically compute per-audio token counts from sample lengths.
            # audio_lengths are in number of samples; assume default sampling rate.
            sampling_rate = getattr(self.feature_extractor, "sampling_rate", 16_000)
            num_audio_tokens = [
                self._compute_audio_num_tokens(np.zeros(length), sampling_rate) for length in audio_lengths
            ]
            vision_data.update({"num_audio_tokens": num_audio_tokens})

        return MultiModalData(**vision_data)

    @property
    def model_input_names(self):
        model_input_names = super().model_input_names
        model_input_names = [
            name
            for name in model_input_names
            if name not in ["num_soft_tokens_per_image", "num_soft_tokens_per_video"]
        ]

        # Include audio feature extractor input names if available
        if self.feature_extractor is not None:
            feature_extractor_input_names = self.feature_extractor.model_input_names
            model_input_names.extend([name for name in feature_extractor_input_names if name not in model_input_names])

        return model_input_names + ["mm_token_type_ids"]


__all__ = ["Gemma4Processor"]
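
The video branch of the processor above replaces each video placeholder in the prompt with one timestamped block per sampled frame. The sketch below reproduces that string expansion in isolation so the resulting prompt shape is visible. The token strings `<boi>`, `<video>` and `<eoi>` and the helper names `mmss` and `expand_video_prompt` are hypothetical; the real token strings come from the tokenizer and are not shown in this diff.

```python
import re

# Hypothetical token strings, used only for illustration.
BOI, VIDEO, EOI = "<boi>", "<video>", "<eoi>"


def mmss(seconds: float) -> str:
    """Format a timestamp in seconds as mm:ss, as the processor does."""
    return f"{int(seconds // 60):02d}:{int(seconds % 60):02d}"


def expand_video_prompt(prompt: str, videos: list) -> str:
    """Expand each VIDEO placeholder using (timestamps, n_tokens) pairs, in order."""
    replacements = iter(
        " ".join(f"{mmss(s)} {BOI}{VIDEO * n_tokens}{EOI}" for s in timestamps)
        for timestamps, n_tokens in videos
    )
    # A function replacement is inserted literally, and re.sub does not rescan
    # the inserted text, so the VIDEO tokens inside a block are left alone.
    return re.sub(re.escape(VIDEO), lambda _: next(replacements), prompt)


# One video with two sampled frames and two soft tokens per frame.
expanded = expand_video_prompt("Describe <video>", [([0.0, 61.5], 2)])
```

Here `expanded` is `"Describe 00:00 <boi><video><video><eoi> 01:01 <boi><video><video><eoi>"`. The minutes field is not wrapped into hours, so a frame at 3600 seconds is labelled `60:00`.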
@@ -0,0 +1,237 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from ...image_processing_utils import BatchFeature
from ...processing_utils import Unpack, VideosKwargs
from ...utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    is_vision_available,
    logging,
)
from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
from ...video_utils import VideoInput
from .image_processing_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size


if is_vision_available():
    from ...image_utils import PILImageResampling

if is_torch_available():
    import torch

if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
    from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)


class Gemma4VideoProcessorKwargs(VideosKwargs, total=False):
    """
    patch_size (`int`, *optional*):
        Size of each image patch in pixels.
    max_soft_tokens (`int`, *optional*):
        Maximum number of soft (vision) tokens per video frame.
        Must be one of {70, 140, 280, 560, 1120}.
    pooling_kernel_size (`int`, *optional*):
        Spatial pooling kernel size applied after patchification.
    """

    patch_size: int
    max_soft_tokens: int
    pooling_kernel_size: int


def convert_video_to_patches(video: "torch.Tensor", patch_size: int) -> "torch.Tensor":
    """
    Convert 4D tensor video of shape (num_frames, num_channels, height, width) into 3D tensor of patches of shape
    (num_frames, num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
    """
    num_frames, num_channels, height, width = video.shape
    num_patches_height = height // patch_size
    num_patches_width = width // patch_size
    patched_video = video.reshape(
        num_frames, num_channels, num_patches_height, patch_size, num_patches_width, patch_size
    )
    patched_video = patched_video.permute(0, 2, 4, 3, 5, 1)
    patched_video = patched_video.reshape(num_frames, num_patches_height * num_patches_width, -1)
    return patched_video


def pad_to_max_patches(
    video: "torch.Tensor", positions: "torch.Tensor", target_length: int
) -> tuple["torch.Tensor", "torch.Tensor"]:
    """
    Pad the video patches and their positions along the patch dimension up to `target_length` patches.
    """
    current_length = video.shape[1]
    padding_length = target_length - current_length
    if padding_length > 0:
        padding = [0, 0, 0, padding_length, 0, 0]
        pos_padding = (0, 0, 0, padding_length, 0, 0)
        video = torch.nn.functional.pad(video, padding, mode="constant", value=0)
        positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1)
    return video, positions


@add_start_docstrings(
    "Constructs a Gemma4 video processor that samples frames from videos for use with the Gemma4 model.",
    BASE_VIDEO_PROCESSOR_DOCSTRING,
)
class Gemma4VideoProcessor(BaseVideoProcessor):
    resample = PILImageResampling.BICUBIC
    image_mean = [0.0, 0.0, 0.0]
    image_std = [1.0, 1.0, 1.0]
    size = None
    default_to_square = True
    do_convert_rgb = True
    do_resize = True
    do_rescale = True
    do_normalize = True
    num_frames = 32
    do_sample_frames = True
    patch_size = 16
    max_soft_tokens = 70
    pooling_kernel_size = 3
    valid_kwargs = Gemma4VideoProcessorKwargs
    model_input_names = ["pixel_values_videos", "video_position_ids"]

    def __init__(self, **kwargs: Unpack[Gemma4VideoProcessorKwargs]):
        super().__init__(**kwargs)

        if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
            raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")

    def _validate_preprocess_kwargs(self, **kwargs):
        # Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
        # max_soft_tokens, and pooling_kernel_size, not the standard `size`
        # parameter. Temporarily disable do_resize so the base validation
        # doesn't require `size` to be set.
        kwargs["do_resize"] = False
        super()._validate_preprocess_kwargs(**kwargs)

    def aspect_ratio_preserving_resize(
        self,
        video: torch.Tensor,
        patch_size: int,
        max_patches: int,
        pooling_kernel_size: int,
        resample: F.InterpolationMode,
    ) -> torch.Tensor:
        height, width = video.shape[-2], video.shape[-1]
        target_height, target_width = get_aspect_ratio_preserving_size(
            height=height,
            width=width,
            patch_size=patch_size,
            max_patches=max_patches,
            pooling_kernel_size=pooling_kernel_size,
        )

        if target_height == height and target_width == width:
            return video

        return F.resize(
            video,
            size=[target_height, target_width],
            interpolation=resample,
            antialias=True,
        )

    def preprocess(
        self,
        videos: VideoInput,
        **kwargs: Unpack[Gemma4VideoProcessorKwargs],
    ) -> BatchFeature:
        return super().preprocess(videos, **kwargs)

    def _preprocess(
        self,
        videos: list["torch.Tensor"],
        do_resize: bool,
        resample: "F.InterpolationMode | int | None",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        return_tensors: str | TensorType | None,
        patch_size: int | None = None,
        max_soft_tokens: int | None = None,
        pooling_kernel_size: int | None = None,
        **kwargs,
    ) -> BatchFeature:
        if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
            raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")

        max_patches = max_soft_tokens * pooling_kernel_size**2

        pixel_values = []
        position_ids = []
        num_soft_tokens_per_video = []
        num_frames = 1

        for video in videos:
            if do_resize:
                video = self.aspect_ratio_preserving_resize(
                    video=video,
                    patch_size=patch_size,
                    max_patches=max_patches,
                    pooling_kernel_size=pooling_kernel_size,
                    resample=resample,
                )

            video = self.rescale_and_normalize(video, do_rescale, rescale_factor, do_normalize, image_mean, image_std)

            num_frames = video.shape[0]
            patch_height = video.shape[-2] // patch_size
            patch_width = video.shape[-1] // patch_size
            patches = convert_video_to_patches(video, patch_size)
            num_soft_tokens_per_video.append(patches.shape[1] // pooling_kernel_size**2)

            device = video.device
            patch_grid = torch.meshgrid(
                torch.arange(patch_width, device=device),
                torch.arange(patch_height, device=device),
                indexing="xy",
            )
            stacked_grid = torch.stack(patch_grid, dim=-1)
            real_positions = stacked_grid.reshape(patches.shape[1], 2)
            real_positions = real_positions[None, ...].repeat(num_frames, 1, 1)

            patches, positions = pad_to_max_patches(patches, real_positions, max_patches)
            pixel_values.append(patches)
            position_ids.append(positions)

        # Stack into batch tensors
        pixel_values = torch.stack(pixel_values, dim=0)  # (num_videos, num_frames, max_patches, patch_pixels)
        position_ids = torch.stack(position_ids, dim=0)  # (num_videos, num_frames, max_patches, 2)

        data = {
            "pixel_values_videos": pixel_values,
            "video_position_ids": position_ids,
            "num_soft_tokens_per_video": num_soft_tokens_per_video,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["Gemma4VideoProcessor"]
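
The per-frame bookkeeping in `_preprocess` (patch count, soft-token count, position ids and padding) can be followed without tensors. The sketch below redoes it in plain Python using the class defaults shown above (`patch_size=16`, `pooling_kernel_size=3`, `max_soft_tokens=70`). The helper name `frame_token_budget` and the 384x384 frame size are assumptions for illustration; the size that `get_aspect_ratio_preserving_size` would actually choose is not shown in this diff.

```python
def frame_token_budget(
    height: int,
    width: int,
    patch_size: int = 16,
    pooling_kernel_size: int = 3,
    max_soft_tokens: int = 70,
) -> dict:
    """Patch and token counts for one already-resized frame (hypothetical helper)."""
    max_patches = max_soft_tokens * pooling_kernel_size**2
    patch_h, patch_w = height // patch_size, width // patch_size
    num_patches = patch_h * patch_w

    # (x, y) pairs in row-major patch order, which is what the "xy" meshgrid
    # followed by reshape(num_patches, 2) produces in the code above.
    positions = [(x, y) for y in range(patch_h) for x in range(patch_w)]

    # Padded slots get position -1, mirroring pad_to_max_patches.
    num_padding = max(max_patches - num_patches, 0)
    positions += [(-1, -1)] * num_padding

    return {
        "max_patches": max_patches,
        "num_patches": num_patches,
        "num_soft_tokens": num_patches // pooling_kernel_size**2,
        "num_padding": num_padding,
        "patch_vector_length": patch_size * patch_size * 3,  # assumes RGB frames
        "positions": positions,
    }


# Assumed frame size after resizing; 384 is a multiple of patch_size * pooling_kernel_size.
budget = frame_token_budget(384, 384)
```

For this assumed frame, `budget` reports 576 real patches out of a 630-patch budget, 64 soft tokens, and 54 padded slots whose positions are `(-1, -1)`.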