docs: add canonical tooling corpus (147 files) from Google/HF/frameworks

Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mortdecai
2026-04-18 12:24:48 -04:00
parent 5011059f5d
commit eecebe7ef5
149 changed files with 181297 additions and 0 deletions
+161
@@ -0,0 +1,161 @@
# Gemma 4 — Hugging Face Canonical Tooling
Downloaded April 2026. First-party Google/HF content only. No weights, no third-party fine-tunes.
## What's here
### `model-cards/`
Verbatim `README.md` from every `google/gemma-4-*` repo (raw endpoint, ungated). Plus the chat template and tokenizer config for two representative variants (31B-it and E4B-it). All eight model cards have identical body text; they differ only in the `pipeline_tag:` YAML frontmatter and size-specific tables.
| File | What it demonstrates |
|------|----------------------|
| `gemma-4-31B-it-README.md` | Flagship dense (33B) instruction-tuned. Full "how to use" from Google+HF. |
| `gemma-4-31B-README.md` | Base (pretrained) variant of the above. |
| `gemma-4-26B-A4B-it-README.md` | MoE (26B params, 4B active) instruction-tuned. The "A4B" = 4B active. |
| `gemma-4-26B-A4B-README.md` | Base MoE. |
| `gemma-4-E4B-it-README.md` | Edge-sized 8B instruction-tuned. Multimodal including audio. |
| `gemma-4-E4B-README.md` | Base E4B. |
| `gemma-4-E2B-it-README.md` | Smallest (5B) instruction-tuned, mobile-targeted. |
| `gemma-4-E2B-README.md` | Base E2B. |
| `gemma-4-31B-it-chat_template.jinja` | **Canonical chat template.** 16KB Jinja — handles system/user/model/tool roles, thinking channel, tool calls, image/audio/video tokens. |
| `gemma-4-E4B-it-chat_template.jinja` | Near-identical to the 31B template (131-byte difference, likely a whitespace-sensitive change around audio handling). |
| `gemma-4-31B-it-tokenizer_config.json` | **Special-token inventory + `response_schema` regex machinery.** See "New capabilities" below. |
| `gemma-4-E4B-it-tokenizer_config.json` | Same shape. |
### `transformers/`
Files under `src/transformers/models/gemma4/` on `huggingface/transformers@main`. Full files for small ones; outlines (signatures + first 12 lines per class/def) for the two large ones.
| File | Lines | What |
|------|-------|------|
| `__init__.py` | 33 | Module exports |
| `configuration_gemma4.py` | 352 | `Gemma4Config`, `Gemma4TextConfig`, `Gemma4AudioConfig`, `Gemma4VisionConfig` — all hyperparams |
| `processing_gemma4.py` | 366 | `Gemma4Processor` — the thing `AutoProcessor.from_pretrained` returns. Includes `parse_response()` |
| `feature_extraction_gemma4.py` | 298 | Audio feature extraction (mel spec, padding) |
| `image_processing_gemma4.py` | 220 | Tensor-backed image preprocessing |
| `image_processing_pil_gemma4.py` | 278 | PIL-backed variant (slower fallback) |
| `video_processing_gemma4.py` | 237 | Frame sampling + stitching to image tokens |
| `modeling_gemma4-OUTLINE.py` | 723 | Outline of the 2657-line modeling file (43 classes: attention, MoE, audio encoder, vision tower, all LM heads) |
| `modular_gemma4-OUTLINE.py` | 563 | Outline of the modular source file — shows Gemma4 **inherits from Gemma3n classes** (RMSNorm, attention blocks etc.) confirming the 3n→4 lineage |
Full files: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gemma4
### `recipes/`
From `huggingface/huggingface-gemma-recipes` — the canonical HF recipe repo. The only Gemma 4-specific recipe as of April 2026 is one notebook; the rest is Gemma 3n which is architecturally the parent of Gemma 4.
| File | What |
|------|------|
| `notebooks/Gemma4_E2B-Multimodal.ipynb` | **The one first-party Gemma-4 recipe.** Original ipynb. 36 cells: image, video, audio, function calling, object detection with `box_2d`, any-to-any pipeline, captioning. |
| `notebooks/Gemma4_E2B-Multimodal-extracted.py` | Same notebook flattened to readable .py for grep/diff. |
| `scripts/ft_gemma3n_image_trl.py` | TRL SFT fine-tune of Gemma 3n on images. Direct precursor to Gemma 4 SFT. |
| `scripts/ft_gemma3n_image_vt.py` | Vision+text fine-tune without TRL (pure Transformers Trainer). |
| `scripts/ft_gemma3n_audio_vt.py` | Audio+text fine-tune. |
| `scripts/gemma3n_fine_tuning_on_all_modalities.py` | All-modalities SFT script — template for full Gemma-4 all-modal SFT. |
| `scripts/carla_vlm_gemma.py` | CARLA driving sim VLM example using Gemma. |
### `trl/`
**Empty as of April 2026.** Searched `huggingface/trl/examples/scripts` — only `sft_gemma3.py` and `sft_vlm_gemma3.py` exist, no gemma4 yet. The gemma-recipes repo's `ft_gemma3n_image_trl.py` is the closest first-party TRL pattern; it is saved under `recipes/scripts/` above.
### `peft/`
**Empty as of April 2026.** `huggingface/peft/examples` has no gemma-specific directory. The canonical HF PEFT guide for Gemma is the blog post `gemma-peft.md`, saved as `gemma-peft-blog.md` under `blog/` below. It covers Gemma 1, but the LoRA target-module patterns apply unchanged to Gemma 4 (same `q_proj/k_proj/v_proj/o_proj` naming).
### `blog/`
| File | What |
|------|------|
| `gemma4-blog.md` | **"Welcome Gemma 4: Frontier multimodal intelligence on device"**, the HF launch blog. 764 lines. Seven listed authors, with merve first. Covers architecture, capabilities, transformers usage, HF Inference API, llama.cpp/MLX quantization, thinking mode examples. |
| `gemma-peft-blog.md` | "Fine-Tuning Gemma Models in Hugging Face" — the PEFT/LoRA recipe blog (gemma-agnostic, target modules unchanged for Gemma 4). |
### `spaces/`
The official HF-run interactive demo Spaces.
| File | What |
|------|------|
| `huggingface-projects_gemma-4-31b-it-app.py` | Official 31B demo (Gradio 6 chat + multimodal). |
| `huggingface-projects_gemma-4-e4b-it-app.py` | Official E4B demo. **More illustrative** — shows the full multimodal+thinking pattern in ~320 lines. |
| `*-requirements.txt` | Pinned deps. **`transformers==5.5.4`** (as of 2026-04-18) — that's the minimum version for Gemma 4 in transformers main line. |
---
## New capabilities the HF integration exposes that weren't in the existing corpus
1. **`AutoModelForMultimodalLM`** — new transformers AutoClass, not `AutoModelForCausalLM`. Required to get any-to-any routing (text+image+audio+video in, text out). The corpus's `CORPUS_capabilities.md` should note this.
2. **`processor.parse_response(text) -> dict`** — built into `Gemma4Processor`. Returns `{thinking, content, tool_calls}` parsed from raw decoded output. Driven by regexes declared in `tokenizer_config.json` under `response_schema` (new HF feature using `x-regex`, `x-regex-iterator`, and a custom `x-parser: gemma4-tool-call`). **You no longer need to hand-roll tool-call regex parsing** if you use the HF processor — this is the HF-canonical replacement for the manual parsing done in `CORPUS_tool_calling_format.md`.
3. **`enable_thinking=True`** — a kwarg to `processor.apply_chat_template()`. When set, injects `<|think|>` at the top of the system turn. **This is how you turn reasoning mode on** through the HF API. Not documented in the existing corpus.
4. **`load_audio_from_video=True`** — another `apply_chat_template` kwarg. Pulls the audio track out of a video URL and feeds it as audio tokens alongside sampled frames. Only relevant for E2B/E4B which have audio; the notebook comment explicitly calls this out.
5. **`pipeline("any-to-any", model=...)`** — a new HF pipeline task registered for Gemma 4. Accepts the chat-style messages list directly. Easiest one-liner for multimodal inference.
6. **Object detection via `box_2d` JSON** — prompting with "What's the bounding box for the X?" returns `[{"box_2d": [ymin, xmin, ymax, xmax], "label": "..."}]` in a 1000x1000 normalized coordinate frame, with images resized to multiples of 48 pixels. This is a Gemma-4-specific convention the notebook demonstrates. Corpus doesn't cover this.
7. **Thinking delimiters are `<|channel>thought...<channel|>`** — not `<thinking>...</thinking>` like some other open-weights models. The Space app explicitly strips these to pass to Gradio 6's `reasoning_tags` for collapsible thinking UI.
8. **Breaking change in role/turn markers vs Gemma 3** — Gemma 3 used `<start_of_turn>user ... <end_of_turn>`. Gemma 4 uses `<|turn>user\n ... <turn|>`. Tokenizer config:
- `sot_token`: `<|turn>` (start of turn)
- `eot_token`: `<turn|>` (end of turn)
- Role after `<|turn>` can be `system`, `user`, `model`, or `tool`.
- `enable_thinking` injects a `<|think|>` marker into the first system turn.
Anything in the homelab that hard-codes `<start_of_turn>` for Gemma needs to branch on family version. Worth adding to `GOTCHAS.md`.
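The family-version branch mentioned above could look something like the following. This is a minimal sketch for auditing or migrating existing prompt strings, not a replacement for the chat template; `TurnMarkers`, `turn_markers`, and `uses_legacy_markers` are hypothetical names, and the marker strings are the ones listed in item 8.

```python
from typing import NamedTuple


class TurnMarkers(NamedTuple):
    start: str
    end: str


# Marker strings copied from the Gemma 3 vs Gemma 4 comparison above.
GEMMA3_MARKERS = TurnMarkers("<start_of_turn>", "<end_of_turn>")
GEMMA4_MARKERS = TurnMarkers("<|turn>", "<turn|>")


def turn_markers(family_version: int) -> TurnMarkers:
    """Return the turn delimiters for a Gemma family version (hypothetical helper)."""
    return GEMMA4_MARKERS if family_version >= 4 else GEMMA3_MARKERS


def uses_legacy_markers(prompt: str) -> bool:
    """Flag prompt strings that still hard-code the Gemma 3 delimiters."""
    return any(marker in prompt for marker in GEMMA3_MARKERS)


print(turn_markers(4))
print(uses_legacy_markers("<start_of_turn>user\nhi<end_of_turn>"))
```

A check like `uses_legacy_markers` is mainly useful as a grep-style audit over existing code before switching a service to Gemma 4.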
---
## Canonical chat template format
**Source of truth:** the two `.jinja` files in `model-cards/`. Use them directly — **do not reimplement.** The tokenizer loads them automatically:
```python
from transformers import AutoModelForMultimodalLM, AutoProcessor

model_id = "google/gemma-4-E4B-it"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto")

# `messages` (chat-style list of role/content dicts) and WEATHER_TOOL
# (OpenAI-style tool schema) are supplied by the caller.
inputs = processor.apply_chat_template(
    messages,
    tools=[WEATHER_TOOL],         # optional; OpenAI-style tool schema
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    enable_thinking=True,         # turns on reasoning, injects <|think|>
    load_audio_from_video=False,  # only for video inputs
).to(model.device)

output = model.generate(**inputs, max_new_tokens=1000)
generated = processor.decode(output[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
result = processor.parse_response(generated)
# → {"thinking": "...", "content": "...", "tool_calls": [...]}
```
### Wire format that the template produces
```
<bos><|turn>system
<|think|>
{system prompt here if any}
<|tool>declaration:get_weather{city:{type:<|"|>STRING<|"|>,description:<|"|>...<|"|>}}<tool|>
<turn|>
<|turn>user
{user text}
<|image|> ← placeholder for each image
<|audio|> ← placeholder for each audio
<|video|> ← placeholder for each video
<turn|>
<|turn>model
<|channel>thought
{reasoning text}
<channel|>
<|tool_call>call:get_weather{city:<|"|>London<|"|>}<tool_call|>
<|tool_response>response:get_weather{temperature:15}<tool_response|>
{final content}
<turn|>
```
Every Gemma-4-specific token appears in `tokenizer_config.json`. The `apply_chat_template` call + the `response_schema` + `parse_response()` round-trip means **homelab code should never hand-emit these tokens** — always go through the processor.
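For inspecting captured raw text (for example a log line that still contains the special tokens) when the processor is not at hand, a regex fallback can separate the thinking channel from the rest. This is only a sketch under the assumption that the text follows the wire format shown above; `split_thinking` is a hypothetical helper and is not how `parse_response()` is implemented.

```python
import re

# Delimiters taken from the wire format above: <|channel>thought ... <channel|>
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>\n?", re.DOTALL)


def split_thinking(raw: str) -> tuple[str, str]:
    """Return (thinking, remainder) for raw text that still has channel tokens."""
    thoughts = [match.strip() for match in THOUGHT_RE.findall(raw)]
    remainder = THOUGHT_RE.sub("", raw)
    return "\n".join(thoughts), remainder


sample = (
    "<|channel>thought\n"
    "Check the forecast first.\n"
    "<channel|>\n"
    "It is 15 degrees in London."
)
thinking, content = split_thinking(sample)
print(repr(thinking))
print(repr(content))
```

For anything that feeds back into the model, the processor round-trip remains the safer path.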
---
## Source URLs (first-party only)
- Model collection: https://huggingface.co/collections/google/gemma-4
- transformers gemma4 dir: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gemma4
- Recipes repo: https://github.com/huggingface/huggingface-gemma-recipes
- Launch blog: https://huggingface.co/blog/gemma4
- Official 31B Space: https://huggingface.co/spaces/huggingface-projects/gemma-4-31b-it
- Official E4B Space: https://huggingface.co/spaces/huggingface-projects/gemma-4-e4b-it
+207
@@ -0,0 +1,207 @@
---
title: Fine-Tuning Gemma Models in Hugging Face
thumbnail: /blog/assets/gemma-peft/thumbnail.png
authors:
- user: svaibhav
guest: true
- user: alanwaketan
guest: true
- user: ybelkada
- user: ArthurZ
---
# Fine-Tuning Gemma Models in Hugging Face
We recently announced that [Gemma](https://huggingface.co/blog/gemma), the open weights language model from Google DeepMind, is available for the broader open-source community via Hugging Face. It's available in 2 billion and 7 billion parameter sizes with pretrained and instruction-tuned flavors. It is hosted on Hugging Face, supported in TGI, and easily accessible for deployment and fine-tuning in the Vertex Model Garden and Google Kubernetes Engine.
<div class="flex items-center justify-center">
<img src="/blog/assets/gemma-peft/Gemma-peft.png" alt="Gemma Deploy">
</div>
The Gemma family of models also happens to be well suited for prototyping and experimentation using the free GPU resource available via Colab. In this post we will briefly review how you can do [Parameter Efficient FineTuning (PEFT)](https://huggingface.co/blog/peft) for Gemma models, using the Hugging Face Transformers and PEFT libraries on GPUs and Cloud TPUs for anyone who wants to fine-tune Gemma models on their own dataset.
## Why PEFT?
The default (full weight) training for language models, even for modest sizes, tends to be memory and compute-intensive. On one hand, it can be prohibitive for users relying on openly available compute platforms for learning and experimentation, such as Colab or Kaggle. On the other hand, and even for enterprise users, the cost of adapting these models for different domains is an important metric to optimize. PEFT, or parameter-efficient fine tuning, is a popular technique to accomplish this at low cost.
## PyTorch on GPU and TPU
Gemma models in Hugging Face `transformers` are optimized for both PyTorch and PyTorch/XLA. This enables both TPU and GPU users to access and experiment with Gemma models as needed. Together with the Gemma release, we have also improved the [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/) experience for PyTorch/XLA in Hugging Face. This [FSDP via SPMD](https://github.com/pytorch/xla/issues/6379) integration also allows other Hugging Face models to take advantage of TPU acceleration via PyTorch/XLA. In this post, we will focus on PEFT, and more specifically on Low-Rank Adaptation (LoRA), for Gemma models. For a more comprehensive set of LoRA techniques, we encourage readers to review [Scaling Down to Scale Up, from Lialin et al.](https://arxiv.org/pdf/2303.15647.pdf) and [this excellent post](https://pytorch.org/blog/finetune-llms/) by Belkada et al.
## Low-Rank Adaptation for Large Language Models
Low-Rank Adaptation (LoRA) is one of the parameter-efficient fine-tuning techniques for large language models (LLMs). It trains only a small fraction of the total number of model parameters, by freezing the original model and training only adapter layers that are decomposed into low-rank matrices. The [PEFT library](https://github.com/huggingface/peft) provides an easy abstraction that allows users to select the model layers where adapter weights should be applied.
```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
```
In this snippet, we refer to all `nn.Linear` layers as the target layers to be adapted.
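To make the "fraction of the parameters" point concrete, here is a small arithmetic sketch. It assumes the standard LoRA decomposition, where a `d_out x d_in` weight is adapted by two matrices of shapes `(d_out, r)` and `(r, d_in)`; the 2048 dimension is a hypothetical value chosen for illustration, not read from a Gemma config.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in linear layer.

    The adapter is B @ A, with A of shape (r, d_in) and B of shape (d_out, r).
    """
    return r * d_in + d_out * r


d_in = d_out = 2048  # hypothetical projection size, for illustration only
full = d_in * d_out
added = lora_param_count(d_in, d_out, r=8)
print(f"full layer: {full:,} params, LoRA adapter: {added:,} params")
print(f"adapter / full = {added / full:.4f}")
```

With `r=8`, the adapter for such a layer is under one percent of the frozen weight it sits beside, which is why the rank is the main knob for trading capacity against memory.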
In the following example, we will leverage [QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes), from [Dettmers et al.](https://arxiv.org/abs/2305.14314), in order to quantize the base model in 4-bit precision for a more memory efficient fine-tuning protocol. The model can be loaded with QLoRA by first installing the `bitsandbytes` library on your environment, and then passing a `BitsAndBytesConfig` object to `from_pretrained` when loading the model.
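As a rough back-of-the-envelope check on why 4-bit loading helps, the sketch below estimates weight storage only. It ignores quantization constants, activations, optimizer state, and the LoRA adapters themselves, and the 2-billion figure is the nominal size named in this post rather than an exact checkpoint count.

```python
def weight_bytes(num_params: int, bits_per_param: float) -> float:
    """Approximate bytes needed to store the weights at a given precision."""
    return num_params * bits_per_param / 8


num_params = 2_000_000_000  # nominal "2 billion"; real checkpoints differ
for label, bits in [("bf16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{label}: ~{weight_bytes(num_params, bits) / 1e9:.1f} GB of weights")
```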
## Before we begin
In order to access Gemma model artifacts, users are required to accept [the consent form](https://huggingface.co/google/gemma-7b-it).
Now let's get started with the implementation.
## Learning to quote
Assuming that you have submitted the consent form, you can access the model artifacts from the [Hugging Face Hub](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
We start by downloading the model and the tokenizer. We also include a `BitsAndBytesConfig` for weight only quantization.
```python
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])
```
Now we test the model before starting the finetuning, using a famous quote:
```python
text = "Quote: Imagination is more"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
The model does a reasonable completion with some extra tokens:
```
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
-Albert Einstein
I
```
But this is not exactly the format we would like the answer to be in. Let's see if we can use fine-tuning to teach the model to generate the answer in the following format.
```
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
Author: Albert Einstein
```
To begin with, let's select an English quotes dataset [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes).
```python
from datasets import load_dataset
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
```
Now let's finetune this model using the LoRA config stated above:
```python
import transformers
from trl import SFTTrainer


def formatting_func(example):
    text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}<eos>"
    return [text]


trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)
trainer.train()
```
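The training arguments above are deliberately tiny. A quick arithmetic sketch, assuming a single device and the usual relation between per-device batch size and gradient accumulation, shows how few examples this run actually sees; `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


# Values taken from the TrainingArguments above; one device is assumed.
batch = effective_batch_size(per_device=1, grad_accum=4)
max_steps = 10
examples_seen = batch * max_steps
print(f"effective batch size: {batch}, examples seen in {max_steps} steps: {examples_seen}")
```

So this run is a smoke test of the pipeline rather than a full fine-tune; `max_steps` would be raised for a real run.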
Finally, we are ready to test the model once more with nearly the same prompt we used earlier:
```python
text = "Quote: Imagination is"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
This time we get the response in the format we like:
```
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
Author: Albert Einstein
```
## Accelerate with FSDP via SPMD on TPU
As mentioned earlier, Hugging Face `transformers` now supports PyTorch/XLA's latest FSDP implementation. This can greatly accelerate fine-tuning. To enable it, one just needs to add an FSDP config to the `transformers.Trainer`:
```python
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True,
}

# Finally, set up the trainer and train the model.
trainer = Trainer(
    model=model,
    train_dataset=data["train"],  # pass the train split, not the whole DatasetDict
    args=TrainingArguments(
        per_device_train_batch_size=64,  # This is actually the global batch size for SPMD.
        num_train_epochs=100,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```
## Next Steps
We walked through this simple example adapted from the source notebook to illustrate the LoRA finetuning method applied to Gemma models. The full colab for GPU can be found [here](https://huggingface.co/google/gemma-7b/blob/main/examples/notebook_sft_peft.ipynb), and the full script for TPU can be found [here](https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py). We are excited about the endless possibilities for research and learning thanks to this recent addition to our open source ecosystem. We encourage users to also visit the [Gemma documentation](https://huggingface.co/docs/transformers/v4.38.0/en/model_doc/gemma), as well as our [launch blog](https://huggingface.co/blog/gemma) for more examples to train, finetune and deploy Gemma models.
+764
@@ -0,0 +1,764 @@
---
title: "Welcome Gemma 4: Frontier multimodal intelligence on device"
thumbnail: /blog/assets/gemma4/thumbnail.png
authors:
- user: merve
- user: pcuenq
- user: sergiopaniego
- user: burtenshaw
- user: Steveeeeeeen
- user: alvarobartt
- user: SaylorTwift
---
# Welcome Gemma 4: Frontier multimodal intelligence on device
The Gemma 4 family of multimodal models by Google DeepMind is out on Hugging Face, with support for your favorite agents, inference engines, and fine-tuning libraries 🤗
These models are the real deal: truly open with Apache 2 licenses, high quality with pareto frontier arena scores, multimodal including audio, and sizes you can use _everywhere_ including on-device. Gemma 4 builds on advances from previous families and makes them click together. In our tests with pre-release checkpoints we have been impressed by their capabilities, to the extent that we struggled to find good fine-tuning examples because they are _so good_ out of the box.
We collaborated with Google and the community to make them available everywhere: transformers, llama.cpp, MLX, WebGPU, Rust; you name it. This blog post will show you how to build with [your favorite tools](https://huggingface.co/collections/google/gemma-4) so let us know what you think!
## Table of Contents
- [What is New with Gemma 4?](#what-is-new-with-gemma-4)
- [Overview of Capabilities and Architecture](#overview-of-capabilities-and-architecture)
- [Architecture at a Glance](#architecture-at-a-glance)
- [Per-Layer Embeddings (PLE)](#per-layer-embeddings-ple)
- [Shared KV Cache](#shared-kv-cache)
- [Multimodal Capabilities](#multimodal-capabilities)
- [Deploy Anywhere](#deploy-anywhere)
- [transformers](#transformers)
- [Llama.cpp](#llamacpp)
- [Plug in to your local agent](#Plug-in-your-local-agent)
- [transformers.js](#transformersjs)
- [MLX](#mlx)
- [Mistral.rs](#mistralrs)
- [Fine-tuning & Demos](#fine-tuning--demos)
- [Fine-tuning with TRL](#fine-tuning-with-trl)
- [Fine-tuning with TRL on Vertex AI](#fine-tuning-with-trl-on-vertex-ai)
- [Fine-tuning with Unsloth Studio](#fine-tuning-with-unsloth-studio)
- [Try Gemma 4](#try-gemma-4)
- [Benchmark Results](#benchmark-results)
- [Acknowledgements](#acknowledgements)
# What is new with Gemma 4?
Similar to Gemma-3n, Gemma 4 supports image, text, and audio inputs, and generates text responses. The text decoder is based on the Gemma model with support for long context windows. The image encoder is similar to the one from Gemma 3 but with two crucial improvements: variable aspect ratios, and configurable number of image token inputs to find your sweet spot between speed, memory, and quality. All models support images (or video) and text inputs, while the small variants (E2B and E4B) support audio as well.
Gemma 4 comes in four sizes, all base and instruction fine-tuned:
| Model | Parameter Size | Context Window | Checkpoints |
| :---- | :---- | :---- | :---- |
| Gemma 4 E2B | 2.3B effective, 5.1B with embeddings | 128K | [base](https://huggingface.co/google/gemma-4-E2B), [IT](https://huggingface.co/google/gemma-4-E2B-it) |
| Gemma 4 E4B | 4.5B effective, 8B with embeddings | 128K | [base](https://huggingface.co/google/gemma-4-E4B), [IT](https://huggingface.co/google/gemma-4-E4B-it) |
| Gemma 4 31B | 31B dense model | 256K | [base](https://huggingface.co/google/gemma-4-31B), [IT](https://huggingface.co/google/gemma-4-31B-it) |
| Gemma 4 26B A4B | mixture-of-experts with 4B activated/26B total parameters | 256K | [base](https://huggingface.co/google/gemma-4-26B-A4B), [IT](https://huggingface.co/google/gemma-4-26B-A4B-it) |
## Overview of Capabilities and Architecture
Gemma 4 leverages several architecture components used in previous Gemma versions and other open models, and leaves out complex or inconclusive features such as Altup. The combination is a mix designed to be highly compatible across libraries and devices, that can efficiently support long context and agentic use cases, whilst being ideal for quantization.
As shown in the benchmarks above, this feature mix (combined with the training data and recipe) enables the 31B dense model to achieve an estimated LMArena score (text only) of 1452, while the 26B MoE reaches 1441 with just 4B active parameters 🤯. As we'll see, multimodal operation is comparatively as good as text generation, at least in informal and subjective tests.
These are the main architecture characteristics in Gemma 4:
* Alternating **local sliding-window** and **global full-context** attention layers. Smaller dense models use sliding windows of 512 tokens while larger models use 1024 tokens.
* **Dual RoPE** configurations: standard RoPE for sliding layers, pruned RoPE for global layers, to enable longer context.
* **Per-Layer Embeddings (PLE)**: a second embedding table that feeds a small residual signal into every decoder layer.
* **Shared KV Cache**: the last N layers of the model reuse key-value states from earlier layers, eliminating redundant KV projections.
* **Vision encoder**: uses learned 2D positions and multidimensional RoPE. Preserves the original aspect ratios and can encode images to a few different token budgets (70, 140, 280, 560, 1120).
* **Audio encoder**: USM-style conformer with the same base architecture as the one in Gemma-3n.
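The difference between the local sliding-window layers and the global full-context layers can be illustrated with a toy attention mask. This is a plain-Python sketch, not the transformers implementation; it assumes a causal mask and a window that counts the current token, and the sizes are tiny illustrative values rather than the 512 or 1024 tokens mentioned above.

```python
from typing import Optional


def attention_mask(seq_len: int, window: Optional[int] = None) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    window=None gives a global (full causal) layer; an integer gives a
    sliding-window layer that only sees the most recent `window` positions.
    """
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            causal = j <= i
            in_window = window is None or i - j < window
            row.append(causal and in_window)
        mask.append(row)
    return mask


local_mask = attention_mask(5, window=2)   # toy sliding-window layer
global_mask = attention_mask(5)            # toy global layer
for row in local_mask:
    print("".join("x" if allowed else "." for allowed in row))
```

In the sliding-window case each row has at most `window` allowed positions, so the keys and values a local layer needs to keep do not grow with context length.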
#### Per-Layer Embeddings (PLE)
One of the most distinctive features in smaller Gemma 4 models is Per-Layer Embeddings (PLE), which was introduced previously in Gemma-3n. In a standard transformer, each token gets a single embedding vector at input, and the same initial representation is what the residual stream builds on across all layers, forcing the embedding to frontload everything the model might need. PLE adds a parallel, lower-dimensional conditioning pathway alongside the main residual stream. For each token, it produces a small dedicated vector for every layer by combining two signals: a token-identity component (from an embedding lookup) and a context-aware component (from a learned projection of the main embeddings). Each decoder layer then uses its corresponding vector to modulate the hidden states via a lightweight residual block after attention and feed-forward. This gives each layer its own channel to receive token-specific information only when it becomes relevant, rather than requiring everything to be packed into a single upfront embedding. Because the PLE dimension is much smaller than the main hidden size, this adds meaningful per-layer specialization at modest parameter cost. For multimodal inputs (images, audio, video), PLE is computed before soft tokens are merged into the embedding sequence — since PLE relies on token IDs that are lost once multimodal features replace the placeholders. Multimodal positions use the pad token ID, effectively receiving neutral per-layer signals.
#### Shared KV Cache
The **shared KV cache** is an efficiency optimization that reduces both compute and memory during inference. The last `num_kv_shared_layers` layers of the model don't compute their own key and value projections. Instead, they **reuse** the K and V tensors from the last non-shared layer of the same attention type (sliding or full).
In practice, this has a minimal impact on quality while being much more efficient (in terms of both memory and compute) for long context generation and on-device use.
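The reuse rule described above can be sketched as a simple lookup from each shared layer to the layer whose K and V tensors it borrows. This is an illustration of the rule as stated in the text, not the library's code; `kv_source_layers`, the `"sliding"`/`"full"` labels, and the six-layer pattern are hypothetical.

```python
def kv_source_layers(layer_types: list[str], num_kv_shared_layers: int) -> dict[int, int]:
    """Map each shared layer index to the layer whose K/V it reuses."""
    first_shared = len(layer_types) - num_kv_shared_layers
    sources = {}
    for idx in range(first_shared, len(layer_types)):
        # Walk backwards to the last non-shared layer of the same attention type.
        for candidate in range(first_shared - 1, -1, -1):
            if layer_types[candidate] == layer_types[idx]:
                sources[idx] = candidate
                break
    return sources


# Toy six-layer pattern; real models have many more layers.
layer_types = ["sliding", "sliding", "full", "sliding", "sliding", "full"]
mapping = kv_source_layers(layer_types, num_kv_shared_layers=2)
print(mapping)
```

Here layer 4 (sliding) borrows from layer 3 and layer 5 (full) borrows from layer 2, so neither needs its own key and value projections.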
## Multimodal Capabilities
We saw in our tests that Gemma 4 supports comprehensive multimodal capabilities out of the box. We don't know what the training mix was, but we had success using it for tasks such as OCR, speech-to-text, object detection, or pointing. It also supports text-only and multimodal function calling, reasoning, code completion and correction.
Here, we show a few inference examples across different model sizes. You can run them conveniently with [this notebook](https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb). We encourage you to try the demos and share them below this blog!
### Object Detection and Pointing
### GUI detection
We test Gemma 4 on GUI element detection and pointing across different sizes, with the following image and text prompt: "What's the bounding box for the "view recipe" element in the image?"
![Image](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/food_resized.png)
With this prompt, the model natively responds in JSON format with the detected bounding boxes - no need for specific instructions or grammar-constrained generation. We found the coordinates refer to an image size of 1000x1000, relative to the input dimensions.
We visualize the outputs below for your convenience. We parse the bounding boxes from the returned JSON: ```json\n[\n {"box_2d": [171, 75, 245, 308], "label": "view recipe element"}\n]\n```
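Translating such a response into pixel coordinates might look like the sketch below. It assumes the `box_2d` values are ordered `[ymin, xmin, ymax, xmax]` in the 1000x1000 frame and that the model may wrap its JSON in a Markdown code fence; `parse_boxes` and the 2000x1000 image size are hypothetical.

```python
import json

FENCE = "`" * 3  # Markdown code-fence marker the model may wrap its JSON in


def parse_boxes(raw: str, width: int, height: int) -> list[dict]:
    """Convert box_2d entries from the 1000x1000 frame to pixel coordinates."""
    text = raw.strip()
    if text.startswith(FENCE):
        # Drop the opening fence line and everything from the closing fence on.
        text = text.split("\n", 1)[1].rsplit(FENCE, 1)[0]
    boxes = []
    for item in json.loads(text):
        # Assumed order: [ymin, xmin, ymax, xmax].
        ymin, xmin, ymax, xmax = item["box_2d"]
        boxes.append({
            "label": item["label"],
            "xmin": round(xmin / 1000 * width),
            "ymin": round(ymin / 1000 * height),
            "xmax": round(xmax / 1000 * width),
            "ymax": round(ymax / 1000 * height),
        })
    return boxes


raw = (
    FENCE + "json\n"
    '[\n {"box_2d": [171, 75, 245, 308], "label": "view recipe element"}\n]\n'
    + FENCE
)
boxes = parse_boxes(raw, width=2000, height=1000)
print(boxes)
```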
| E2B | E4B |
| :---- | :---- |
| ![E2B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/e2b.png) | ![E4B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/e4b.png) |
| 26B A4B | 31B |
| ![26B A4B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/26b.png) | ![31B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/31b.png) |
### Object Detection
We test the models on detecting everyday objects. Here we ask them to detect the bike and compare the outputs of the different models. As in the previous case, we parse the bounding box from the JSON and translate it to image-space coordinates.
| E2B | E4B | 26B A4B | 31B |
| :---- | :---- | :---- | :---- |
| ![E2B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/e2b_bike.png) | ![E4B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/bike_e4b.png) | ![26B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/bike_26b.png) | ![31B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/31b_bike.png) |
### Multimodal Thinking and Function Calling
We asked Gemma 4 to write HTML code to reconstruct a page we made with Gemini 3. The code to do this is below: we enable thinking and ask each model to generate up to 4000 new tokens, to be on the safe side.
| Gemini Generated Website (Reference) | Gemini Reproduced Image |
| :---- | :---- |
| ![Reference](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/landing_page.png) | ![Reference](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/page_repro_gemini.png) |
<details>
<summary>Inference code</summary>
```py
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
},
{"type": "text", "text": "Write HTML code for this page."},
],
}
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
enable_thinking=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=4000)
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
print(result["content"])
```
</details>
| E2B | E4B |
| :---- | :---- |
| ![E2B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/page_e2b.png) | ![E4B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/page_e4b.png) |
| 31B | 26B/A4B |
| ![31B](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/page_31b.png) | ![MoE](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/page_26b.png) |
### Video Understanding
The smaller Gemma 4 models can take in videos with audio, while the larger ones take in videos without audio. Although the models are not explicitly post-trained on videos, they can understand videos both with and without audio. The models are particularly strong on audio.
<details>
<summary>Inference code</summary>
```py
messages = [
{
"role": "user",
"content": [
{"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
{"type": "text", "text": "What is happening in the video? What is the song about?"},
],
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
load_audio_from_video=True, # disable this for larger models
).to(model.device)
output = model.generate(**inputs, max_new_tokens=200)
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
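generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
# parse the decoded text before printing; `result` was previously undefined here
result = processor.parse_response(generated_text)
print(result["content"])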
```
</details>
| Model | Output |
| :---- | :---- |
| E2B (hallucinated for the audio) | Based on the video provided, here is an analysis: **What is happening in this video?** The video captures a live musical performance, likely at a large outdoor festival or concert. * **Setting:** The scene is a large stage set up outdoors, indicated by the extensive lighting rig, large speakers, and the audience visible in the foreground. * **Performers:** There are several musicians on stage. The focal point seems to be a lead singer/guitarist in a blue shirt and white pants who is actively performing, playing an acoustic or electric guitar. There are also other musicians, including drummers and bassists, visible in the background. * **Atmosphere:** The atmosphere is energetic and dramatic, characterized by bright stage lights, intense blue and white spotlights, and a significant amount of stage fog/smoke, which adds to the concert mood. * **Audience:** The foreground is filled with the backs of audience members, showing they are close to the stage and engaged in watching the performance. **What is the song about?** The lyrics provided below the video are: > "I don't care 'bout street > From what I cherish > Could it's time to protest |
| E4B ✅ | The video captures a live performance of a band on a stage with bright lights and smoke effects, surrounded by an audience. The band performs and sings to the crowd. The song appears to be about the struggles and disillusionment of modern life, specifically the feeling of being stuck and the difficulties in finding one's place in a chaotic and overwhelming world. Lyrics mention themes of the "concrete street," "broken promises," "storm of the day," and the difficulty of finding oneself amidst societal pressures. |
| 26B/A4B (no audio) ✅ | This video shows a live music performance at an outdoor concert or festival at night. A solo male artist, dressed in a blue shirt and white pants, is the central focus. He is playing an acoustic guitar and singing into a microphone on a large, lit stage. The performance is accompanied by a band, including a drummer and other musicians visible in the background and on large video screens. The stage is filled with atmospheric blue lighting, smoke, and bright spotlights. The video captures different perspectives, including shots from the crowd looking up at the stage and close-up views of the performers projected onto large screens. The crowd is visible in the foreground, watching the show. One of the screens also displays the "Estrella Damm" brand name. |
| 31B (no audio) ✅ | A musician is performing on stage during a concert, playing an acoustic guitar and singing into a microphone. He is wearing a blue shirt and white pants. The stage is filled with atmospheric smoke and illuminated by bright blue and white lights. Other band members, including a drummer and a keyboardist, are also visible on stage. The video shows the musician from the perspective of the audience, with some shots focusing on him and others showing the large screen on the side of the stage, which displays close-ups of the performer and the crowd. |
### Captioning
We have tested all models on captioning. All checkpoints perform very well and accurately capture nuance in complex scenarios.
Here's the image we use, with the prompt "Write single detailed caption for this image.".
![image](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png)
<details>
<summary>Inference code</summary>
```py
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
{"type": "text", "text": "Write single detailed caption for this image."},
],
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=512)
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
print(result["content"])
```
</details>
| Model | Output |
| :---- | :---- |
| E2B ✅ | A medium shot captures a weathered seagull perched atop a stone pedestal in what appears to be a bustling European square, with a grand, classical-style building featuring ornate columns and architectural details dominating the right side of the frame. In the background, a low, light-colored building stretches across the middle ground, flanked by some green foliage, and hints of other structures, including a terracotta-roofed building, are visible. In the foreground, part of a modern bus stop or shelter is visible, with signage indicating routes, and the overall scene is bathed in bright daylight under a pale sky. |
| E4B ✅ | A solitary seagull perches atop a grey stone pedestal in a bustling European piazza, with the historical grandeur of a classical building dominating the right side of the frame, contrasting with the more modern, low-rise architecture in the background under a bright, hazy sky. |
| 26B/A4B ✅ | A full shot shows a white and brown seagull perched on a gray pillar in Venice, Italy. The seagull faces left, its beak pointed towards the edge of the frame. Behind it is a light gray, two-story building with several dark rectangular windows. To the right of the seagull is a large, white building with columns and intricate carvings. The sky is a clear, light blue. The lighting is bright and even, suggesting it is daytime. |
| 31B ✅ | A high-angle, medium shot shows a white and gray seagull perched on a gray concrete pylon in Venice, Italy. The seagull is facing left, its body angled slightly towards the camera. Its white head, neck, and chest contrast with its gray wings and back. The seagull's orange beak is slightly open. The pylon is a light gray concrete cylinder with a weathered texture. In the background, there are several buildings. To the left is a long, low, white building with a series of small, rectangular windows. In the center is a cluster of green trees. To the right is a red building with a tiled roof and several small, rectangular windows. Further right is a large, ornate, white stone building with several columns and sculptures. The sky is a pale blue, and the light is bright and clear. In the lower-left corner, there is a yellow bus with the word "Ferrovia" printed on it in black letters. There are several people walking in the background, but they are small and out of focus. |
### Audio Question Answering
These models are trained to answer questions about speech in audio. Music and non-speech sounds were not part of the training data.
<details>
<summary>Inference code</summary>
```py
messages = [
{
"role": "user",
"content": [
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
{"type": "text", "text": "Can you describe this audio in detail?"},
],
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
output = model.generate(
**inputs,
max_new_tokens=1000,
do_sample=False,
)
print(processor.decode(output[0], skip_special_tokens=True))
```
</details>
| Model | Output |
| :---- | :---- |
| E2B | This audio is a personal reflection. The speaker is talking about their final farewell address to the nation, which they delivered in Chicago. They express gratitude for the conversations they've had with the American people, noting that despite not having met them face-to-face or even greeted them, these interactions in various settings like living rooms, schools, farms, factory floors, diners, and military outposts have been what has kept them going. |
| E4B | The audio is a speech excerpt where a speaker is delivering a farewell address to the nation from Chicago. The speaker reflects on their time in office, expressing gratitude for the conversations they had with the American people across various settings like living rooms, schools, farms, factories, diners, and military outposts. The tone is reflective and appreciative, highlighting the importance of these interactions in their political journey. |
Here is an example if you want to do transcription:
<details>
<summary>Inference code</summary>
```py
messages = [
{
"role": "user",
"content": [
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
{"type": "text", "text": "Transcribe the audio?"},
],
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
output = model.generate(
**inputs,
max_new_tokens=1000,
do_sample=False,
)
print(processor.decode(output[0], skip_special_tokens=True))
```
</details>
| Model | Output |
| :---- | :---- |
| E2B | This week I traveled to Chicago to deliver my final farewell address to the nation following in the tradition of presidents before me It was an opportunity to say thank you whether we've seen eye to eye or rarely agreed at all my conversations with you the American people in living rooms and schools at farms and on factory floors at diners and on distant military outposts all these conversations are what have kept me honest |
| E4B | This week I traveled to Chicago to deliver my final farewell address to the nation following in the tradition of presidents before me. It was an opportunity to say thank you. Whether we've seen eye to eye or rarely agreed at all, my conversations with you, the American people, in living rooms and schools, at farms and on factory floors, at diners and on distant military outposts, all these conversations are what have kept me honest. |
### Multimodal Function Calling
We test the model by asking it to get the weather for the place shown in the image.
<details>
<summary>Inference code</summary>
```py
import re
WEATHER_TOOL = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather for a specific location.",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city name"},
},
"required": ["city"],
},
},
}
tools = [WEATHER_TOOL]
messages = [
{"role": "user", "content": [
{"type": "text", "text": "What is the city in this image? Check the weather there right now."},
{"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
]},
]
inputs = processor.apply_chat_template(
messages,
tools=[WEATHER_TOOL],
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
enable_thinking=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=1000)
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
print(result["content"])
```
</details>
| Model | Output |
| :---- | :---- |
| E2B | **Analyze the Request:** The user has provided an image and asked two questions:\n * "What is the city in this image?" * "Check the weather there right now."\n\n2. **Analyze the Image:** The image displays a large, ornate, tall structure, characteristic of temple architecture, likely in Southeast Asia (given the style). It appears to be a famous landmark. 3. **Identify the Landmark/Location (Internal Knowledge/Visual Cues):** The structure strongly resembles the chedis/stupas found in places like Bangkok, Thailand (e.g., Wat Arun or similar structures).\n\n4. **Determine the City:** Based on the visual evidence, the location is highly likely to be Bangkok, Thailand.\n\n5. **Check Available Tools:** The available tool is `get_weather(city: STRING)`. 6. **Formulate the Plan:**\n * Step 1: Identify the city from the image (Bangkok).\n * Step 2: Use the `get_weather` tool with the identified city.\n\n7. **Construct the Tool Call:** `get_weather(city="Bangkok")`\n\n8. **Generate Response:** The model should first state the identified city and then call the tool to get the weather. (Self-correction: The model *must* call the tool in the correct format if it needs external information.)call:get_weather{city:Bangkok} |
| E4B | '1. **Analyze the Request:** The user has provided an image and asked two things:\n a. Identify the city in the image.\n b. Check the weather in that city right now.\n\n2. **Analyze the Image:** The image shows a very distinctive, ornate tower/spire structure, characteristic of Thai architecture (specifically, a chedi or prang).\n\n3. **Determine City Identification (Internal Knowledge/Visual Recognition):** Based on the architectural style, the structure strongly resembles prominent temples in Thailand, such as Wat Arun in Bangkok.\n\n4. **Evaluate Tool Availability:** I have one available tool: `get_weather(city: STRING)`. This tool requires a city name as input.\n\n5. **Formulate a Plan:**\n a. I need the city name to use the `get_weather` tool.\n b. I must first identify the city from the image.\n c. Since my internal knowledge suggests Bangkok, I will proceed with that assumption, but I must acknowledge that I cannot *definitively* prove the location from the image alone.\n\n6. **Execute Step 1 (City Identification):** Assume the city is Bangkok based on the visual evidence of the architecture.\n\n7. **Execute Step 2 (Tool Call Generation):** Use the `get_weather` tool with "Bangkok" as the `city` argument.\n\n8. **Construct the Response:** Generate the appropriate tool call in JSON format.', 'answer': '<\|tool_call>call:get_weather{city:<\|"\|>Bangkok<\|"\|>}<tool_call\|><eos>' |
| 31B | The user wants to know the city in the image and the current weather there. 1. **Analyze the image:** The image shows ornate, spire-like structures (prangs) characteristic of Thai temple architecture. Specifically, these look like the prangs at Wat Phra Kaew (Temple of the Emerald Buddha) or Wat Pho in Bangkok, Thailand. The style is very distinctive of Bangkok's royal temple complexes. 2. **Identify the city:** Based on the architectural style, the city is Bangkok. 3. **Find the weather:** I need to call the `get_weather` tool with `city="Bangkok"`.call:get_weather{city:Bangkok} |
| 26B/A4B | The image shows several tall, ornate towers (prangs) typical of Thai Buddhist temple architecture. Specifically, these look like the structures within the Grand Palace or Wat Phra Kaew in Bangkok, Thailand. The style is unmistakably Thai. I'll confirm this by looking at the specific towers. These are very iconic.\n2. **Confirm the location:** The image shows elements of the Grand Palace complex in Bangkok.\n3. **Get weather information:** Once the city is identified as Bangkok, I will use the `get_weather` tool with `city='Bangkok'`.call:get_weather{city:Bangkok} |
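To act on outputs like the ones in the table, the trailing `call:get_weather{city:Bangkok}` fragment has to be turned into a real function call. The sketch below is one possible way to do that with a regular expression. It assumes the call syntax matches what appears in the table (including the `<|"|>` string delimiter visible in the E4B row). `parse_tool_call` and the stub `get_weather` are hypothetical names for illustration, and the naive comma split will not handle argument values that themselves contain commas.

```python
import re

# ASSUMPTION: tool calls look like `call:name{key:value,...}` as in the table above.
CALL_RE = re.compile(r"call:(\w+)\{(.*?)\}", flags=re.DOTALL)
QUOTE = '<|"|>'  # string delimiter seen in the raw E4B output


def parse_tool_call(text):
    """Return {"name": ..., "arguments": {...}} for the first tool call, or None."""
    match = CALL_RE.search(text)
    if match is None:
        return None
    name, body = match.groups()
    arguments = {}
    for pair in body.split(","):
        if ":" not in pair:
            continue
        key, value = pair.split(":", 1)
        arguments[key.strip()] = value.replace(QUOTE, "").strip()
    return {"name": name, "arguments": arguments}


def get_weather(city):
    # Hypothetical stub; a real implementation would query a weather service.
    return f"(stub) weather for {city}"


TOOLS = {"get_weather": get_weather}

call = parse_tool_call("I need to call the tool.call:get_weather{city:Bangkok}")
if call is not None:
    print(TOOLS[call["name"]](**call["arguments"]))
```

The tool result would then be appended to `messages` and sent back through the chat template for the model's final answer.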
# Deploy Anywhere
Gemma 4 comes with day-0 support for many open-source inference engines, and is ideal for tool calling and agents! We also release ONNX checkpoints that can run on many hardware backends, enabling use cases on edge devices or in the browser!
## transformers
Gemma 4 comes with first-class transformers support from the get-go 🤗. This integration allows using the model with other libraries like bitsandbytes, PEFT and TRL. Make sure to install the latest version of transformers.
```bash
pip install -U transformers
```
The easiest way to infer with the small Gemma 4 models is through the `any-to-any` pipeline. You can initialize it as follows.
```py
from transformers import pipeline
pipe = pipeline("any-to-any", model="google/gemma-4-E2B-it")
```
You can then pass in images and text as follows.
```python
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg",
},
{"type": "text", "text": "Do you have travel advice going to here?"},
],
}
]
output = pipe(messages, max_new_tokens=100, return_full_text=False)
output[0]["generated_text"]
# Based on the image, which appears to show a magnificent, ornate **Buddhist temple or pagoda**, likely in Southeast Asia (such as Thailand, Myanmar, or Cambodia), here is some general travel advice..
```
When inferring with videos, you can include the audio track using the `load_audio_from_video` argument.
```python
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
},
{"type": "text", "text": "What is happening in this video?"},
],
}
]
pipe(messages, load_audio_from_video=True)
```
Going a level lower, you can load Gemma 4 using the `AutoModelForMultimodalLM` class, which is especially useful for fine-tuning. The built-in chat template takes care of formatting the inputs correctly. Please make sure you use it to prevent subtle mistakes when building the prompt manually.
<details>
<summary>Inference code</summary>
```python
from transformers import AutoModelForMultimodalLM, AutoProcessor
model = AutoModelForMultimodalLM.from_pretrained("google/gemma-4-E2B-it", device_map="auto")
processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it")
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
},
{"type": "text", "text": "What is happening in this video?"},
],
}
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
</details>
## Llama.cpp
Gemma 4 models come with image+text support in llama.cpp from the get-go! This unlocks using Gemma 4 with all of your favorite local apps, such as the llama.cpp server, LM Studio, and Jan, as well as coding agents like Pi, across many backends such as Metal and CUDA.
You can install llama.cpp as follows.
```bash
brew install llama.cpp # macOS
winget install llama.cpp # Windows
```
You can then start a server compatible with the OpenAI API. To pick a specific precision, append a quantization tag such as `:Q4_K_M` to the model name.
```bash
llama-server -hf ggml-org/gemma-4-E2B-it-GGUF
```
Check out this link [for more](https://huggingface.co/ggml-org/gemma-4-E2B-it-GGUF?local-app=llama.cpp) options on combining llama.cpp with different coding agents and local apps. Find all the GGUF checkpoints [in this collection](https://huggingface.co/collections/ggml-org/gemma-4).
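Once `llama-server` is running, any OpenAI-style client can talk to it. Here is a minimal sketch using only the Python standard library. It assumes the server listens on `http://localhost:8080/v1` (the same base URL used in the agent configs in the next section) and exposes the usual `/chat/completions` route. The `model` value is a placeholder, and `build_chat_request` and `chat` are hypothetical helper names.

```python
import json
import urllib.request


def build_chat_request(prompt, base_url="http://localhost:8080/v1", model="gemma-4"):
    """Build a POST request for an OpenAI-compatible chat completions endpoint."""
    payload = {
        "model": model,  # placeholder name; adjust to whatever your server reports
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 128,
    }
    return urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt, **kwargs):
    """Send the prompt and return the assistant's reply text."""
    with urllib.request.urlopen(build_chat_request(prompt, **kwargs)) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]


# Requires a running server, so the call is left commented out:
# print(chat("Give me one sentence about Gemma 4."))
```

Keeping request construction separate from the network call makes it easy to inspect the payload before sending anything.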
## Plug in your local agent
We worked on making sure the new models work locally with agents like **openclaw, hermes, pi, and open code**. All thanks to llama.cpp! Run the following to try Gemma 4 right away.
First, start your local server:
```
llama-server -hf ggml-org/gemma-4-26b-a4b-it-GGUF:Q4_K_M
```
For **hermes:**
```shell
hermes model
```
For **openclaw:**
```shell
openclaw onboard
```
For **pi** define a `~/.pi/agent/models.json`:
```json
{
"providers": {
"llama-cpp": {
"baseUrl": "http://localhost:8080/v1",
"api": "openai-completions",
"apiKey": "none",
"models": [
{
"id": "ggml-org-gemma-4-26b-4b-gguf"
}
]
}
}
}
```
For **open code** define a `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"llama.cpp": {
"npm": "@ai-sdk/openai-compatible",
"name": "llama-server (local)",
"options": {
"baseURL": "http://127.0.0.1:8080/v1"
},
"models": {
"gemma-4-26b-4b-it": {
"name": "Gemma 4 (local)",
"limit": {
"context": 128000,
"output": 8192
}
}
}
}
}
}
```
## transformers.js
transformers.js enables running Gemma 4 right inside the browser. You can check out the model card to see text-only, image & text, audio & text inference in detail [here](https://huggingface.co/onnx-community/gemma-4-E2B-it-ONNX#transformersjs-javascript). We also shipped a demo for you to test the model [here](https://huggingface.co/spaces/webml-community/Gemma-4-WebGPU).
## MLX
Full multimodal support of Gemma 4 is available using the open-source [`mlx-vlm` library](https://github.com/Blaizzy/mlx-vlm). Here's how to ask the model to describe an image:
```shell
pip install -U mlx-vlm
```
```shell
mlx_vlm.generate \
--model google/gemma-4-E4B-it \
--image https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg \
--prompt "Describe this image in detail"
```
mlx-vlm supports TurboQuant, which delivers the same accuracy as the uncompressed baseline while using ~4x less active memory and running a lot faster end-to-end. This makes long-context inference practical on Apple Silicon without sacrificing quality. Use it like this:
```shell
mlx_vlm.generate \
--model "mlx-community/gemma-4-26b-a4b-it-4bit" \
--prompt "Your prompt here" \
--kv-bits 3.5 \
--kv-quant-scheme turboquant
```
For audio examples and more details, please check [the MLX collection](https://hf.co/mlx-community/gemma-4).
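To build intuition for why a low `--kv-bits` setting matters at long context, here is a back-of-the-envelope estimate of KV cache size. It is only a rough sketch: the layer, head, and dimension values are hypothetical and are not Gemma 4's real configuration, and the estimate counts raw key/value storage only, ignoring quantization metadata such as scales.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bits_per_value):
    """Estimate raw KV cache size: one K and one V tensor per layer."""
    n_values = 2 * n_layers * n_kv_heads * head_dim * seq_len
    return n_values * bits_per_value / 8


# HYPOTHETICAL dimensions, chosen only to illustrate the arithmetic.
dims = dict(seq_len=128_000, n_layers=32, n_kv_heads=8, head_dim=128)

baseline = kv_cache_bytes(bits_per_value=16, **dims)    # 16-bit cache
quantized = kv_cache_bytes(bits_per_value=3.5, **dims)  # matches --kv-bits 3.5

gib = 2 ** 30
print(f"16-bit:  {baseline / gib:.2f} GiB")
print(f"3.5-bit: {quantized / gib:.2f} GiB")
print(f"ratio:   {baseline / quantized:.2f}x")
```

Because the cache grows linearly with sequence length, the ratio between the two settings stays the same at any context size; only the absolute savings change.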
## Mistral.rs
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a Rust-native inference engine with day-0 Gemma 4 support across all modalities (text, image, video, audio) and built-in tool-calling and agentic functionality. Install mistral.rs:
```bash
curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/EricLBuehler/mistral.rs/master/install.sh | sh # Linux/macOS
irm https://raw.githubusercontent.com/EricLBuehler/mistral.rs/master/install.ps1 | iex # Windows
```
You can then start an OpenAI-compatible HTTP server:
```bash
mistralrs serve mistralrs-community/gemma-4-E4B-it-UQFF --from-uqff 8
```
Or, use interactive mode:
```
mistralrs run -m google/gemma-4-E4B-it --isq 8 --image image.png -i "Describe this image in detail."
mistralrs run -m google/gemma-4-E4B-it --isq 8 --audio audio.mp3 -i "Transcribe this fully."
```
Find all models [here](https://huggingface.co/mistralrs-community/models). Please follow [the instructions](https://huggingface.co/mistralrs-community/gemma-4-E2B-it-UQFF#install) in the model cards for installation and inference guidelines.
## Fine-tuning for all
Gemma 4 models are ideal for fine-tuning in your favorite tools and platforms and at any budget.
## Fine-tuning with TRL
Gemma 4 is fully supported for fine-tuning with TRL. To celebrate, TRL has been upgraded with support for multimodal tool responses when interacting with environments, meaning models can now receive images back from tools during training, not just text.
To showcase this, we've built an example training script where Gemma 4 learns to drive in the CARLA simulator. The model sees the road through a camera, decides what to do and learns from the outcome. After training, it consistently changes lanes to avoid pedestrians. The same approach works for any task where a model needs to see and act: robotics, web browsing, or other interactive environments.
Get started:
```shell
# pip install git+https://github.com/huggingface/trl.git
python examples/scripts/openenv/carla_vlm_gemma.py \
--env-urls https://sergiopaniego-carla-env.hf.space \
https://sergiopaniego-carla-env-2.hf.space \
--model google/gemma-4-E2B-it
```
Find the example [here](https://github.com/huggingface/huggingface-gemma-recipes/blob/main/scripts/carla_vlm_gemma.py).
### Fine-tuning with TRL on Vertex AI
Additionally, we have prepared an example of how to fine-tune Gemma 4 with TRL on Vertex AI using SFT. It showcases how to extend the function calling capabilities while freezing both the vision and audio towers. The example includes how to build a custom Docker container with the latest Transformers, TRL, etc., with CUDA support on Google Cloud, and how to run it via Vertex AI Serverless Training Jobs.
```python
# pip install google-cloud-aiplatform --upgrade --quiet
from google.cloud import aiplatform
aiplatform.init(
project="<PROJECT_ID>",
location="<LOCATION>",
staging_bucket="<BUCKET_URI>",
)
job = aiplatform.CustomContainerTrainingJob(
display_name="gemma-4-fine-tuning",
container_uri="<CONTAINER_URI>",
command=["python", "/gcs/gemma-4-fine-tuning/train.py"],
)
job = job.submit(
replica_count=1,
machine_type="a3-highgpu-1g",
accelerator_type="NVIDIA_H100_80GB",
accelerator_count=1,
base_output_dir="<BUCKET_URI>/output-dir",
environment_variables={
"MODEL_ID": "google/gemma-4-E2B-it",
"HF_TOKEN": "<HF_TOKEN>",
},
boot_disk_size_gb=500,
)
```
You can find the complete example in the "Hugging Face on Google Cloud" docs at https://hf.co/docs/google-cloud/examples/vertex-ai-notebooks-fine-tune-gemma-4.
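The idea of freezing the vision and audio towers can be sketched in a framework-agnostic way. The snippet below is an illustration, not the code from the linked example: the `vision_tower` and `audio_tower` name fragments are assumptions about how the parameters are named, and `FakeModel` is a stand-in so the sketch runs without downloading weights. With PyTorch, you would pass the real model, which exposes the same `named_parameters()` interface.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Stand-in for a framework parameter with a requires_grad flag."""
    requires_grad: bool = True


class FakeModel:
    """Tiny stand-in that mimics the named_parameters() interface."""

    def __init__(self, names):
        self._params = {name: FakeParam() for name in names}

    def named_parameters(self):
        return self._params.items()


def freeze_towers(model, frozen_keywords=("vision_tower", "audio_tower")):
    """Disable gradients for parameters whose name contains any keyword."""
    # ASSUMPTION: the keywords match the real parameter names; inspect
    # model.named_parameters() on the actual checkpoint to confirm.
    frozen = []
    for name, param in model.named_parameters():
        if any(keyword in name for keyword in frozen_keywords):
            param.requires_grad = False
            frozen.append(name)
    return frozen


model = FakeModel([
    "model.vision_tower.layer0.weight",
    "model.audio_tower.layer0.weight",
    "model.language_model.layer0.weight",
])
print(freeze_towers(model))
```

Returning the list of frozen names makes it easy to confirm that the keywords matched something before launching a training job.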
## Fine-tuning with Unsloth Studio
If you want to fine-tune and run a Gemma 4 model in a UI, try out [Unsloth Studio](https://unsloth.ai/docs/new/studio). It runs locally or on Google Colab. First, install and start the app:
```shell
# install unsloth studio on macOS, Linux, WSL
curl -fsSL https://unsloth.ai/install.sh | sh
# install unsloth studio on Windows
irm https://unsloth.ai/install.ps1 | iex
# launch unsloth studio
unsloth studio -H 0.0.0.0 -p 8888
# Search for a Gemma 4 model like google/gemma-4-E2B-it
```
Then select any of the Gemma 4 models from the hub.
![Unsloth Studio](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma4/unsloth.png)
## Try Gemma 4
We have shipped demos for you to try different Gemma 4 models. We include demos based on transformers implementation for [E4B](https://huggingface.co/spaces/huggingface-projects/gemma-4-e4b-it), [26B/A4B](https://huggingface.co/spaces/huggingface-projects/gemma-4-26b-a4b-it), and dense [31B](https://huggingface.co/spaces/huggingface-projects/gemma-4-31b-it) models, as well as a [WebGPU](https://huggingface.co/spaces/webml-community/Gemma-4-WebGPU) demo with transformers.js 🚀
<iframe width="560" height="315" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/webgpu_demo.mp4" title="WebGPU Demo" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen> </iframe>
## Benchmark Results
Gemma 4 models demonstrate exceptional performance across diverse benchmarks, from reasoning and coding to vision and long-context tasks. The graph below shows model performance vs size, with Gemma 4 models forming an impressive Pareto frontier:
<div style="display: flex; gap: 20px; justify-content: center; align-items: flex-start; flex-wrap: wrap;">
<figure style="flex: 1; min-width: 300px; text-align: center;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/g4_graph.png" alt="Gemma 4 Performance vs Size" style="width: 100%; height: 400px; object-fit: contain;">
</figure>
<figure style="flex: 1; min-width: 300px; text-align: center;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/g4-blog/g4_graph_2.png" alt="Gemma 4 Arena Elo Score Comparison" style="width: 100%; height: 400px; object-fit: contain;">
</figure>
</div>
<p style="text-align: center; font-size: 0.9em; color: #666;">Source: Google (<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/">blog.google</a>)</p>
Here are detailed benchmark results for the instruction-tuned models:
| Benchmark | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
|-----------|-------------|-----------------|-------------|-------------|------------------------|
| **Reasoning & Knowledge** |
| MMLU Pro | [85.2%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-31B-it) | [82.6%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-26B-A4B-it) | [69.4%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-E4B-it) | [60.0%](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro?eval_result=google/gemma-4-E2B-it) | 67.6% |
| AIME 2026 no tools | [89.2%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-31B-it) | [88.3%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-26B-A4B-it) | [42.5%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-E4B-it) | [37.5%](https://huggingface.co/datasets/MathArena/aime_2026?eval_result=google/gemma-4-E2B-it) | 20.8% |
| GPQA Diamond | [84.3%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-31B-it) | [82.3%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-26B-A4B-it) | [58.6%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-E4B-it) | [43.4%](https://huggingface.co/datasets/Idavidrein/gpqa?eval_result=google/gemma-4-E2B-it) | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Coding** |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| HLE no tools | [19.5%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-31B-it) | [8.7%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-26B-A4B-it) | - | - | - |
| HLE with search | [26.5%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-31B-it) | [17.2%](https://huggingface.co/datasets/cais/hle?eval_result=google/gemma-4-26B-A4B-it) | - | - | - |
| **Vision** |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (edit distance) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## Acknowledgements
Landing Gemma-4 in the open-source ecosystem took a lot of effort from many people and not only the authors of this blog post. In no particular order, we thank many people from the open-source team: Gemma 4 transformers integration is owed to Cyril, Raushan, Eustache, Arthur, Lysandre. We thank Joshua for transformers.js integration and demo, Eric for mistral.rs integration, Son for Llama.cpp, Prince for MLX integration, Quentin, Albert and Kashif for TRL, Adarsh for SGLang transformers backend and Toshihiro for building the demos.
This work wouldn't have been possible without Google's extensive contribution: not only the model artefact itself, but also their significant effort in contributing the model to transformers in order to standardize it. The open-source ecosystem is now more complete, with a very capable, freely-licensed, open-source model.
@@ -0,0 +1,514 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: image-text-to-text
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning** All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities** Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures** Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device** Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window** The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities** Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support** Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
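The difference between the two layer types can be sketched with plain causal masks. This is a minimal illustration, not the Transformers implementation: the `global_every` interleaving ratio is a hypothetical parameter (the card only says that local and global layers are interleaved and that the final layer is global), and the window is assumed to include the current token.

```python
def causal_mask(seq_len, window=None):
    """Return mask[i][j] = True when query position i may attend to key j.

    window=None gives full (global) causal attention. Otherwise each query
    sees only the most recent `window` positions, itself included.
    """
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            visible = j <= i and (window is None or i - j < window)
            row.append(visible)
        mask.append(row)
    return mask


def layer_types(num_layers, global_every):
    """Hypothetical interleaving: every `global_every`-th layer is global.

    The last layer is forced to be global, as the card describes.
    """
    types = ["global" if (i + 1) % global_every == 0 else "local"
             for i in range(num_layers)]
    types[-1] = "global"
    return types


# A local layer with a window of 2: the last query sees only 2 keys.
print(causal_mask(4, window=2)[3])
# A toy 5-layer stack with one global layer per 3 layers.
print(layer_types(5, global_every=3))
```

The memory saving comes from the local layers: their key/value cache only needs the last `window` positions (512 or 1024 tokens in the tables below), while global layers keep the whole context.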
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
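As a rough illustration of the gap between effective and total size, the snippet below computes how much of each small model is not counted as "effective". The figures are copied from the table above; treating the whole difference as embedding tables is an assumption based on the "with embeddings" wording.

```python
# Parameter counts in billions, copied from the Dense Models table.
MODELS = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_share(effective, with_embeddings):
    """Fraction of all parameters that is not counted as 'effective'."""
    return (with_embeddings - effective) / with_embeddings


for name, params in MODELS.items():
    share = embedding_share(params["effective"], params["with_embeddings"])
    print(f"{name}: about {share:.0%} of parameters are lookup-only embeddings")
```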
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
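The idea of activating only a few experts per token can be sketched with a toy top-k router. This is a generic Mixture-of-Experts illustration, not Gemma 4's actual router, which the card does not describe; the random scores stand in for router logits, and only the expert counts and parameter totals come from the table above.

```python
import random


def route_top_k(router_scores, k=8):
    """Return the indices of the k highest-scoring experts for one token."""
    ranked = sorted(range(len(router_scores)),
                    key=lambda i: router_scores[i], reverse=True)
    return sorted(ranked[:k])


random.seed(0)
NUM_EXPERTS = 128  # total routed experts, from the table
ACTIVE = 8         # routed experts used per token, from the table

# Placeholder router scores for a single token.
scores = [random.random() for _ in range(NUM_EXPERTS)]
chosen = route_top_k(scores, ACTIVE)

print(f"experts used for this token: {len(chosen)} routed + 1 shared")
print(f"active parameter share: {3.8 / 25.2:.1%}")
```

Only the chosen experts (plus the shared one) run for a given token, so compute per token scales with the 3.8B active parameters, while all 25.2B parameters still have to be held in memory.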
## **Benchmark Results**
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking** Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context** Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding** Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding** Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input** Freely mix text and images in any order within a single prompt.
* **Function Calling** Native support for structured tool use, enabling agentic workflows.
* **Coding** Code generation, completion, and correction.
* **Multilingual** Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only) Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto",
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto",
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ],
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto",
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto",
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."},
        ],
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
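One way to apply these values with Transformers is to pass them as keyword arguments to `model.generate`. The helper below is a small sketch; `generation_kwargs` is a hypothetical name, and `do_sample=True` is included on the assumption that the checkpoint's generation config does not already enable sampling.

```python
# Recommended sampling values from the list above.
SAMPLING = {"temperature": 1.0, "top_p": 0.95, "top_k": 64}


def generation_kwargs(max_new_tokens=512):
    """Build keyword arguments for `model.generate(**inputs, **kwargs)`.

    `do_sample=True` switches generation from greedy decoding to sampling,
    which is what makes temperature, top_p and top_k take effect.
    """
    return {"do_sample": True, "max_new_tokens": max_new_tokens, **SAMPLING}


kwargs = generation_kwargs(max_new_tokens=1024)
print(kwargs)
# With a loaded model and processed inputs:
# outputs = model.generate(**inputs, **kwargs)
```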
### 2. Thinking Mode Configuration
Compared to Gemma 3, the models use standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
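
If you work with raw strings rather than a chat template, the control tokens above can be handled with a few lines of Python. This is a minimal sketch under stated assumptions: `system_prompt` and `split_thought` are hypothetical helpers, the trigger token is assumed to be followed directly by the system text, and any other special tokens in the raw output are left untouched.

```python
import re

THINK_TOKEN = "<|think|>"
# Matches the thought block described above; DOTALL lets reasoning span lines.
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def system_prompt(text, thinking):
    """Prefix the system prompt with the thinking trigger when requested."""
    return f"{THINK_TOKEN}{text}" if thinking else text


def split_thought(raw_output):
    """Return (reasoning, final_answer) from raw decoded model output."""
    match = THOUGHT_RE.search(raw_output)
    if match is None:
        return "", raw_output.strip()
    return match.group(1).strip(), raw_output[match.end():].strip()


print(system_prompt("You are a helpful assistant.", thinking=True))
print(split_thought("<|channel>thought\nAdd 2 and 2.<channel|>4"))
print(split_thought("<|channel>thought\n<channel|>Hello!"))
```

The last call covers the empty thought block that the larger models emit when thinking is disabled.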
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
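The rule above can be sketched as a small helper that removes the thought block before a model reply is stored in the conversation history. `append_assistant_turn` is a hypothetical name, and the regular expression assumes the `<|channel>thought\n...<channel|>` structure described in the previous section.

```python
import re

# Thought block as described under "Thinking Mode Configuration".
THOUGHT_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def append_assistant_turn(history, raw_output):
    """Return a new history with the assistant reply added, thoughts removed."""
    final_answer = THOUGHT_RE.sub("", raw_output).strip()
    return history + [{"role": "assistant", "content": final_answer}]


history = [{"role": "user", "content": "What is 2 + 2?"}]
history = append_assistant_turn(
    history, "<|channel>thought\nSimple sum.<channel|>4"
)
history.append({"role": "user", "content": "And times 3?"})

for turn in history:
    print(turn)
```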
### 4. Modality Order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
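To see why lower budgets suit video, it helps to estimate the visual token count for a batch of frames. The sketch below treats the budget as an upper bound on tokens per image; `visual_token_cost` is a hypothetical helper, and how the budget is actually passed to the processor is not shown here because the card does not specify it.

```python
# Supported visual token budgets, from the list above.
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def visual_token_cost(num_images, budget):
    """Upper bound on visual tokens for `num_images` at a given budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"budget must be one of {SUPPORTED_BUDGETS}")
    return num_images * budget


# 60 video frames (one per second for a 60-second clip).
print(visual_token_cost(60, 70))    # lowest budget
print(visual_token_cost(60, 1120))  # highest budget
# A single document page for OCR at the highest budget.
print(visual_token_cost(1, 1120))
```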
### 6. Audio
Use the following prompt structures for audio processing:
* **Automatic Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
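The two prompt structures can be filled in programmatically. The sketch below stores them as Python format strings and builds a user turn in the message shape used by the audio example earlier in this card; `audio_message` and the example URL are placeholders for illustration.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)
AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it "
    "into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', then the "
    "translation in {target}."
)


def audio_message(audio_url, prompt):
    """Build a user turn with the audio placed before the text."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }


# Placeholder URL; replace it with a real audio file.
msg = audio_message(
    "https://example.com/clip.wav",
    AST_TEMPLATE.format(source="French", target="English"),
)
print(msg["content"][1]["text"])
```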
### 7. Audio and Video Length
All models support image inputs and can process videos as sequences of frames, while only the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are sampled at one frame per second.
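These limits can be enforced before calling the model. The helpers below are a sketch, not part of any library: splitting long audio into consecutive chunks and truncating long video are assumed workarounds that the card does not prescribe.

```python
MAX_AUDIO_SECONDS = 30   # from the card
MAX_VIDEO_SECONDS = 60   # from the card, at one frame per second
FRAMES_PER_SECOND = 1


def audio_chunks(duration_s, max_len=MAX_AUDIO_SECONDS):
    """Split [0, duration_s) into consecutive spans of at most max_len seconds."""
    chunks = []
    start = 0.0
    while start < duration_s:
        end = min(start + max_len, float(duration_s))
        chunks.append((start, end))
        start = end
    return chunks


def frame_timestamps(duration_s, fps=FRAMES_PER_SECOND, max_len=MAX_VIDEO_SECONDS):
    """Timestamps in seconds of frames sampled at `fps`, truncated to max_len."""
    usable = min(duration_s, max_len)
    return [i / fps for i in range(int(usable * fps))]


print(audio_chunks(75))               # a 75-second recording
print(len(frame_timestamps(90)))      # a 90-second video is truncated
print(frame_timestamps(3))
```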
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
For all areas of safety testing, we saw major improvements in all categories of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models in improving safety, while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the model produced minimal policy violations, and showed significant improvements over previous Gemma models' performance.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
* The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
* VLMs can be misused to generate text that is false, misleading, or harmful.
* Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Continuous monitoring (using evaluation metrics and human review) and the exploration of de-biasing techniques are encouraged during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language model implementations designed from the ground up for responsible AI development compared to similarly sized models.
@@ -0,0 +1,513 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: image-text-to-text
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning** All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities** Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures** Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device** Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window** The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities** Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support** Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
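
As a rough illustration of the gap between effective and total parameter counts, the sketch below uses only the figures from the table above. It assumes the whole difference is attributable to embedding tables, which the card implies but does not break down; the helper `embedding_share` is hypothetical.

```python
# Figures from the Dense Models table, in billions of parameters.
MODELS = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_share(effective, with_embeddings):
    """Return (embedding params in billions, fraction of the total)."""
    embedding = with_embeddings - effective
    return embedding, embedding / with_embeddings


for name, p in MODELS.items():
    emb, share = embedding_share(p["effective"], p["with_embeddings"])
    print(f"{name}: ~{emb:.1f}B in embedding tables ({share:.0%} of total)")
```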
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
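
The active-versus-total relationship can be checked with a few lines of arithmetic. This sketch only uses the parameter counts from the tables above; the helper `ratio` is hypothetical, and the idea that per-token compute scales roughly with active parameters is a simplification.

```python
# Figures from the tables above, in billions of parameters.
MOE_TOTAL = 25.2
MOE_ACTIVE = 3.8
DENSE_31B = 30.7


def ratio(part, whole):
    """Return part / whole as a fraction."""
    return part / whole


print(f"Active share of the MoE: {ratio(MOE_ACTIVE, MOE_TOTAL):.1%}")
print(f"Active MoE params vs dense 31B: {ratio(MOE_ACTIVE, DENSE_31B):.1%}")
```

Note that the full set of expert weights still has to be available at inference time, so memory requirements follow the total count rather than the active count.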
## **Benchmark Results**
These models were evaluated against a large collection of datasets and metrics to cover different aspects of text generation. The results in the table are for the instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking** Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context** Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding** Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding** Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input** Freely mix text and images in any order within a single prompt.
* **Function Calling** Native support for structured tool use, enabling agentic workflows.
* **Coding** Code generation, completion, and correction.
* **Multilingual** Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only) Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-26B-A4B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
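
To make the effect of these three parameters concrete, here is a simplified, pure-Python sketch of temperature scaling followed by top-k and top-p (nucleus) filtering. It is an illustration only: the functions `filter_probs` and `sample` are hypothetical, and inference libraries differ in details such as the order of the filters and tie handling.

```python
import math
import random


def filter_probs(logits, temperature=1.0, top_k=64, top_p=0.95):
    """Return {token_id: probability} after temperature, top-k and top-p."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    # Keep the top_k most likely tokens, most likely first.
    ranked = sorted(enumerate(weights), key=lambda t: t[1], reverse=True)[:top_k]
    total = sum(w for _, w in ranked)
    kept, cumulative = [], 0.0
    for token_id, w in ranked:
        p = w / total
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:  # smallest prefix whose mass reaches top_p
            break
    norm = sum(p for _, p in kept)
    return {token_id: p / norm for token_id, p in kept}


def sample(logits, rng=random, **kwargs):
    """Draw one token id from the filtered distribution."""
    dist = filter_probs(logits, **kwargs)
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


toy_logits = [4.0, 3.0, 1.0, 0.5, -2.0]
print(filter_probs(toy_logits))
print(sample(toy_logits, rng=random.Random(0)))
```

With Transformers, the same values are typically passed to `model.generate` as `do_sample=True`, `temperature=1.0`, `top_p=0.95`, and `top_k=64`.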
### 2. Thinking Mode Configuration
Unlike Gemma 3, the models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except the E2B and E4B variants, the model still generates the tags when thinking is disabled, but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
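
If you are working with raw decoded text rather than a library helper such as `parse_response`, the thought block can be separated from the final answer with plain string handling. The sketch below relies only on the control tokens described above; the function `split_thought` is hypothetical and does not strip any other special tokens that may appear in the raw output.

```python
THOUGHT_OPEN = "<|channel>thought\n"
THOUGHT_CLOSE = "<channel|>"


def split_thought(raw):
    """Return (thought, answer) from a raw decoded response.

    If no complete thought block is found, the thought is empty and the
    whole text is treated as the answer.
    """
    start = raw.find(THOUGHT_OPEN)
    end = raw.find(THOUGHT_CLOSE)
    if start == -1 or end == -1 or end < start:
        return "", raw.strip()
    thought = raw[start + len(THOUGHT_OPEN):end]
    answer = raw[end + len(THOUGHT_CLOSE):]
    return thought.strip(), answer.strip()


print(split_thought("<|channel>thought\nAdd 2 and 2.<channel|>The answer is 4."))
print(split_thought("<|channel>thought\n<channel|>Hello!"))
```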
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
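
One way to follow this rule when managing the conversation yourself is to strip the thought block before appending a model turn to the history. This is a minimal sketch under the assumption that the raw response uses the `<channel|>` closing tag described in the previous section; `strip_thought` and `append_model_turn` are hypothetical helper names.

```python
def strip_thought(raw, close_tag="<channel|>"):
    """Drop everything up to and including the closing thought tag."""
    _, found, tail = raw.partition(close_tag)
    return tail.strip() if found else raw.strip()


def append_model_turn(history, raw_response):
    """Return a new history with only the final answer added."""
    turn = {"role": "assistant", "content": strip_thought(raw_response)}
    return history + [turn]


history = [{"role": "user", "content": "What is 2 + 2?"}]
history = append_model_turn(history, "<|channel>thought\nSimple sum.<channel|>4")
history.append({"role": "user", "content": "And times 3?"})
print(history)
```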
### 4. Modality Order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
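
The trade-off between budgets can be estimated with simple arithmetic. The sketch below treats the budget as an upper bound on tokens per image or frame, which is an assumption; the helpers `clamp_budget` and `max_visual_tokens` are hypothetical and are not part of any library API.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def clamp_budget(requested):
    """Return the largest supported budget that does not exceed `requested`."""
    eligible = [b for b in SUPPORTED_BUDGETS if b <= requested]
    return max(eligible) if eligible else SUPPORTED_BUDGETS[0]


def max_visual_tokens(num_images, budget):
    """Upper bound on visual tokens for `num_images` images or frames."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported token budget: {budget}")
    return num_images * budget


# A 60-frame video at the lowest and highest budgets.
print(max_visual_tokens(60, 70))
print(max_visual_tokens(60, 1120))
print(clamp_budget(300))
```

The gap between the two video estimates shows why lower budgets are the practical choice when many frames share one context window.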
### 6. Audio
Use the following prompt structures for audio processing:
* **Automatic Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
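
The two prompt structures above can be filled in programmatically. The following sketch copies the templates from this section and wraps them in the message shape used by the audio example in Getting Started; the constant and function names are hypothetical, and the audio URL is a placeholder.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it "
    "into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', then the "
    "translation in {target}."
)


def audio_message(audio_url, prompt):
    """Build a user message with the audio placed before the text."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }


asr_prompt = ASR_TEMPLATE.format(language="English")
ast_prompt = AST_TEMPLATE.format(source="Spanish", target="English")
message = audio_message("https://example.com/clip.wav", ast_prompt)  # placeholder URL
print(message["content"][1]["text"])
```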
### 7. Audio and Video Length
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are sampled at one frame per second.
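
Inputs that exceed these limits need to be trimmed or split before they reach the model. The sketch below shows one possible way to plan that, using only the 30-second and 60-second limits stated above; chunking long audio is an assumption of this example rather than a documented recommendation, and both helper names are hypothetical.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_SECONDS = 60


def audio_chunks(duration_s, limit=MAX_AUDIO_SECONDS):
    """Split [0, duration_s) into consecutive (start, end) spans <= limit."""
    chunks, start = [], 0
    while start < duration_s:
        end = min(start + limit, duration_s)
        chunks.append((start, end))
        start = end
    return chunks


def frame_timestamps(duration_s, fps=1.0, limit=MAX_VIDEO_SECONDS):
    """Timestamps (in seconds) of frames to sample, capped at `limit` seconds."""
    usable = min(duration_s, limit)
    return [i / fps for i in range(int(usable * fps))]


print(audio_chunks(70))
print(len(frame_timestamps(90)))
```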
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety, while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models' performance.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
* The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
* VLMs can be misused to generate text that is false, misleading, or harmful.
* Guidelines for responsible use of the model are provided; see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Developers are encouraged to monitor continuously (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language model implementations that, compared to similarly sized models, are designed from the ground up for responsible AI development.
@@ -0,0 +1,513 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: image-text-to-text
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning** All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities** Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures** Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device** Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window** The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities** Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support** Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
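
The interleaving constraint described above (a mix of local and global layers, with the final layer always global) can be sketched as a layer schedule. The ratio of local to global layers is not stated in this card, so the value `locals_per_global=5` below is purely hypothetical, as is the helper `layer_schedule`.

```python
def layer_schedule(num_layers, locals_per_global):
    """Return a list of 'local'/'global' labels that ends on 'global'."""
    period = locals_per_global + 1
    schedule = []
    for idx in range(num_layers):
        # Count from the end so the final layer always lands on a global slot.
        from_end = num_layers - 1 - idx
        schedule.append("global" if from_end % period == 0 else "local")
    return schedule


# 60 layers as in the 31B table; the 5:1 ratio is a made-up example value.
schedule = layer_schedule(60, locals_per_global=5)
print(schedule[-6:])
print(schedule.count("global"), "global layers out of", len(schedule))
```

Counting positions from the last layer rather than the first is one simple way to guarantee the final-layer property for any layer count.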
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
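
A toy model of Per-Layer Embeddings helps show why the stored parameter count is large while the per-token work stays small. This is a sketch with made-up sizes, not the actual architecture: the helper names are hypothetical, and how the looked-up vector is combined with the hidden state is not described in this card and is left out.

```python
import random


def make_ple_tables(num_layers, vocab_size, ple_dim, seed=0):
    """Create one small embedding table per decoder layer (toy values)."""
    rng = random.Random(seed)
    return [
        [[rng.uniform(-1.0, 1.0) for _ in range(ple_dim)] for _ in range(vocab_size)]
        for _ in range(num_layers)
    ]


def lookup(tables, layer, token_id):
    """Fetch the per-layer embedding for one token: a plain index operation."""
    return tables[layer][token_id]


def stored_params(num_layers, vocab_size, ple_dim):
    """Total values held in all per-layer tables."""
    return num_layers * vocab_size * ple_dim


# Toy sizes, chosen only to keep the example small.
LAYERS, VOCAB, DIM = 4, 10, 3
tables = make_ple_tables(LAYERS, VOCAB, DIM)
touched = sum(len(lookup(tables, layer, 7)) for layer in range(LAYERS))
print(stored_params(LAYERS, VOCAB, DIM), "values stored")
print(touched, "values read for one token")
```

The stored count grows with the vocabulary size, while the values read per token depend only on the number of layers and the embedding width.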
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
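
The expert counts in the table (8 active out of 128, plus 1 shared) correspond to a top-k routing step. The sketch below illustrates that selection with random scores; the function `route` is hypothetical, and the card does not describe how router scores are computed or how expert outputs are weighted, so both are omitted.

```python
import random


def route(router_scores, top_k=8):
    """Return the indices of the top_k highest-scoring experts."""
    ranked = sorted(
        range(len(router_scores)),
        key=lambda i: router_scores[i],
        reverse=True,
    )
    return sorted(ranked[:top_k])


rng = random.Random(0)
scores = [rng.random() for _ in range(128)]  # stand-in router scores for one token
chosen = route(scores)

# The shared expert runs for every token in addition to the routed ones.
experts_run = len(chosen) + 1
print(chosen)
print(f"{experts_run} of {len(scores) + 1} experts run for this token")
```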
## **Benchmark Results**
These models were evaluated against a large collection of datasets and metrics to cover different aspects of text generation. The results in the table are for the instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking** Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context** Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding** Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding** Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input** Freely mix text and images in any order within a single prompt.
* **Function Calling** Native support for structured tool use, enabling agentic workflows.
* **Coding** Code generation, completion, and correction.
* **Multilingual** Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only) Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-31B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-31B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-31B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
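The recommended values can be kept in one place and passed to `generate` as keyword arguments. The following is a minimal sketch, not part of the official card: the names `GEMMA4_SAMPLING` and `generation_kwargs` are hypothetical, and `do_sample=True` is added on the assumption that you want Transformers to sample rather than decode greedily.

```python
# Recommended sampling settings from this card, collected in one place.
# do_sample=True is an addition: without it, Transformers decodes greedily
# and the temperature/top_p/top_k values have no effect.
GEMMA4_SAMPLING = {
    "do_sample": True,
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}


def generation_kwargs(max_new_tokens: int = 512, **overrides) -> dict:
    """Merge the recommended sampling settings with per-call overrides."""
    kwargs = dict(GEMMA4_SAMPLING, max_new_tokens=max_new_tokens)
    kwargs.update(overrides)
    return kwargs


# Hypothetical usage with the objects loaded in "Getting Started":
# outputs = model.generate(**inputs, **generation_kwargs(max_new_tokens=1024))
```

Keeping the defaults in a single dictionary makes it easy to see when a call site deviates from the recommended configuration.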
### 2. Thinking Mode Configuration
In contrast to Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
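If you work with raw decoded text rather than `parse_response`, the thought block can be separated from the final answer with a small regular expression. This is a sketch under assumptions: the helper names `split_thought` and `with_thinking` are hypothetical, the `\n` after `thought` is treated as a literal newline, and any other special tokens left in the decoded text (for example turn delimiters) are not handled here.

```python
import re

# Control tokens as documented in this section.
THOUGHT_OPEN = "<|channel>thought\n"
THOUGHT_CLOSE = "<channel|>"

_THOUGHT_RE = re.compile(
    re.escape(THOUGHT_OPEN) + r"(.*?)" + re.escape(THOUGHT_CLOSE),
    flags=re.DOTALL,
)


def split_thought(raw: str) -> tuple:
    """Return (thought, answer) from raw decoded text.

    If no thought block is present, thought is "" and answer is the stripped input.
    """
    match = _THOUGHT_RE.search(raw)
    if match is None:
        return "", raw.strip()
    thought = match.group(1).strip()
    answer = (raw[: match.start()] + raw[match.end():]).strip()
    return thought, answer


def with_thinking(system_prompt: str) -> str:
    """Prepend the <|think|> token so that thinking is enabled."""
    return "<|think|>" + system_prompt
```

An empty thought block, as produced by the larger models when thinking is disabled, yields an empty `thought` string and the unchanged answer.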
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
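One way to follow this rule is to clean assistant turns before appending the next user message. The sketch below assumes the history is a list of `{"role", "content"}` dictionaries with string content, as in the text-only example; `strip_thoughts` is a hypothetical helper, not a Transformers API.

```python
import re

# Thought block delimiters as documented in the thinking-mode section.
_THOUGHT_BLOCK = re.compile(r"<\|channel>thought\n.*?<channel\|>", flags=re.DOTALL)


def strip_thoughts(messages: list) -> list:
    """Return a copy of the history with thought blocks removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        if message["role"] == "assistant" and isinstance(content, str):
            content = _THOUGHT_BLOCK.sub("", content).strip()
        # Build a new dict so the caller's history is left untouched.
        cleaned.append({**message, "content": content})
    return cleaned
```

The function returns new dictionaries, so the original history (including the thoughts) remains available for logging or debugging.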
### 4. Modality order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
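The trade-off can be made concrete with a little arithmetic. The sketch below only reasons about the budget values listed above; it deliberately does not show how the budget is passed to the processor, since this card does not specify that. `snap_budget` and `video_token_cost` are hypothetical helpers, and the cost estimate assumes every frame consumes its full budget.

```python
# Visual token budgets listed in this card.
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def snap_budget(requested: int) -> int:
    """Return the smallest supported budget >= requested, capped at the maximum."""
    for budget in SUPPORTED_BUDGETS:
        if budget >= requested:
            return budget
    return SUPPORTED_BUDGETS[-1]


def video_token_cost(num_frames: int, budget: int) -> int:
    """Rough visual token count when every frame uses the same budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported budget: {budget}")
    return num_frames * budget


# 60 frames at the lowest budget versus the highest budget.
low = video_token_cost(60, 70)
high = video_token_cost(60, 1120)
print(low, high)
```

Under these assumptions, the gap between the lowest and highest budget is a factor of 16 per image, which is why lower budgets suit multi-frame video input.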
### 6. Audio
Use the following prompt structures for audio processing:
* **Automatic Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
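The two prompt structures can be filled in programmatically and combined with an audio item placed before the text. This is a minimal sketch: the helper names (`asr_prompt`, `ast_prompt`, `audio_message`) are hypothetical, and the message layout simply mirrors the audio example in "Getting Started".

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', "
    "then the translation in {target}."
)


def asr_prompt(language: str) -> str:
    """Fill in the speech recognition prompt for one language."""
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source: str, target: str) -> str:
    """Fill in the speech translation prompt for a language pair."""
    return AST_TEMPLATE.format(source=source, target=target)


def audio_message(audio_url: str, prompt: str) -> dict:
    """Build a user turn with the audio item placed before the text."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }
```

The resulting dictionary can be wrapped in a list and passed to `processor.apply_chat_template` in the same way as the audio example.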
### 7. Audio and Video Length
All models support image inputs and can process videos as frames, whereas the E2B and E4B models also support audio inputs. Audio inputs can be at most 30 seconds long. Video inputs can be at most 60 seconds long, assuming frames are sampled at one frame per second.
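A pre-flight check on clip duration avoids sending inputs that exceed these limits. The sketch below is illustrative only: `fits_limit` and `segments_needed` are hypothetical helpers, and splitting a longer clip into equal pieces is one possible workaround rather than a recommendation from this card.

```python
import math

# Limits stated in this section (the video limit assumes one frame per second).
LIMITS_SECONDS = {"audio": 30.0, "video": 60.0}


def fits_limit(duration_s: float, modality: str) -> bool:
    """True if a clip of this duration is within the documented limit."""
    return 0 < duration_s <= LIMITS_SECONDS[modality]


def segments_needed(duration_s: float, modality: str) -> int:
    """Number of pieces a clip must be cut into so that each piece fits the limit."""
    return max(1, math.ceil(duration_s / LIMITS_SECONDS[modality]))
```

For example, a 95-second recording would need four audio segments under this scheme, while a 60-second video fits in a single request.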
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the models' own capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
* The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
* VLMs can be misused to generate text that is false, misleading, or harmful.
* Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language model implementations designed from the ground up for responsible AI development compared to similarly sized models.
@@ -0,0 +1,513 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: image-text-to-text
---
<div align="center">
<img src="https://ai.google.dev/gemma/images/gemma4_banner.png">
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning**: All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities**: Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures**: Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device**: Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window**: The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities**: Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support**: Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
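The difference between local and global layers comes down to which earlier positions a token may attend to. The sketch below illustrates the idea in plain Python and is not the actual implementation: the window convention (the current token plus the preceding `window - 1` tokens) and the `local_per_global` ratio are assumptions, since this card states only that the layers are interleaved and that the final layer is global.

```python
from typing import List, Optional


def can_attend(query_pos: int, key_pos: int, window: Optional[int]) -> bool:
    """Causal visibility check; window=None means full global attention."""
    if key_pos > query_pos:
        return False  # never attend to future positions
    if window is None:
        return True  # global layer: all earlier positions are visible
    return query_pos - key_pos < window  # local layer: only recent positions


def layer_types(num_layers: int, local_per_global: int) -> List[str]:
    """Interleave local and global layers so that the final layer is global.

    local_per_global is a hypothetical ratio chosen for illustration.
    """
    period = local_per_global + 1
    types = []
    for i in range(num_layers):
        # Count back from the last layer so that it always lands on "global".
        is_global = (num_layers - 1 - i) % period == 0
        types.append("global" if is_global else "local")
    return types


print(layer_types(6, 2))
```

With a 512-token window, a query at position 600 can see position 100 in a local layer but not position 50; a global layer sees both, which is why only the global layers carry long-range context.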
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only a roughly 4B-parameter subset (3.8B) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
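The parameter counts from the tables make this concrete. The arithmetic below is a rough illustration only: it assumes per-token compute scales with the number of active parameters, while all of the weights still need to be held in memory.

```python
# Parameter counts in billions, taken from the tables in this card.
MOE_TOTAL_B = 25.2
MOE_ACTIVE_B = 3.8
DENSE_31B_B = 30.7

# Share of the MoE weights that participate in each forward pass.
active_fraction = MOE_ACTIVE_B / MOE_TOTAL_B

# How many times more parameters the dense 31B model uses per token.
dense_ratio = DENSE_31B_B / MOE_ACTIVE_B

print(f"active fraction: {active_fraction:.1%}")
print(f"dense 31B vs. MoE active parameters: {dense_ratio:.1f}x")
```

Under that assumption, about 15% of the MoE weights are used per token, and the dense 31B model touches roughly eight times as many parameters per token.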
## **Benchmark Results**
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking**: Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context**: Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding**: Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding**: Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input**: Freely mix text and images in any order within a single prompt.
* **Function Calling**: Native support for structured tool use, enabling agentic workflows.
* **Coding**: Code generation, completion, and correction.
* **Multilingual**: Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only): Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM
MODEL_ID = "google/gemma-4-31B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]
# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E2B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-31B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision torchcodec librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-31B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
### 2. Thinking Mode Configuration
In contrast to Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
### 4. Modality order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
### 6. Audio
Use the following prompt structures for audio processing:
* **Audio Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
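The two prompt structures above can be filled in programmatically. The sketch below is one possible way to do it: the template text is copied from this section, the message layout follows the audio snippet in the Getting Started material (audio placed before text), and the helper names `asr_prompt`, `ast_prompt`, and `audio_message` are hypothetical.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate "
    "it into {TARGET_LANGUAGE}.\n"
    "When formatting the answer, first output the transcription in "
    "{SOURCE_LANGUAGE}, then one newline, then output the string "
    "'{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}."
)


def asr_prompt(language):
    """Fill the speech-recognition template for one language."""
    return ASR_TEMPLATE.format(LANGUAGE=language)


def ast_prompt(source_language, target_language):
    """Fill the speech-translation template for a language pair."""
    return AST_TEMPLATE.format(
        SOURCE_LANGUAGE=source_language, TARGET_LANGUAGE=target_language
    )


def audio_message(audio, prompt):
    """Build a user message with the audio part placed before the text part."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio},
            {"type": "text", "text": prompt},
        ],
    }


print(ast_prompt("English", "French"))
```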
### 7. Audio and Video Length
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are sampled at one frame per second.
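These limits translate into simple arithmetic when planning inputs. The sketch below assumes each sampled frame consumes at most the chosen visual token budget, so the video figure is an upper bound rather than an exact count. Splitting long audio into 30-second windows is a workaround suggested here, not something the card prescribes, and all function names are hypothetical.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_SECONDS = 60
FRAMES_PER_SECOND = 1


def video_frame_count(duration_s):
    """Approximate number of frames sampled at one frame per second."""
    if duration_s > MAX_VIDEO_SECONDS:
        raise ValueError("video exceeds the 60-second limit")
    return int(duration_s * FRAMES_PER_SECOND)


def video_token_upper_bound(duration_s, budget):
    """Upper bound on visual tokens if every frame uses the full budget."""
    return video_frame_count(duration_s) * budget


def audio_chunks(duration_s, chunk_s=MAX_AUDIO_SECONDS):
    """Split [0, duration_s) into consecutive windows of at most chunk_s seconds."""
    chunks = []
    start = 0
    while start < duration_s:
        end = min(start + chunk_s, duration_s)
        chunks.append((start, end))
        start = end
    return chunks


print(video_token_upper_bound(60, 70))
print(audio_chunks(70))
```

Under these assumptions, a full 60-second clip costs up to 4,200 tokens at a budget of 70 and up to 67,200 tokens at 1120, which is why the lower budgets are recommended for video.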
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated and human evaluations were conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every content-safety category relative to previous Gemma models. Overall, Gemma 4 models are significantly safer than Gemma 3 and 3n models while keeping unjustified refusals low. All testing was conducted without safety filters in order to evaluate the models' own capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
    * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
    * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
    * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
    * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
    * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
    * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
    * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
    * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
    * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
    * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
    * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
    * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
    * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
    * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
    * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
    * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
    * VLMs can be misused to generate text that is false, misleading, or harmful.
    * Guidelines for responsible use of the model are provided; see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
    * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
    * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
### **Benefits**
Relative to similarly sized models at the time of release, this family of models provides high-performance open vision-language model implementations designed from the ground up for responsible AI development.
@@ -0,0 +1,347 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'OBJECT' -%}
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
}
{%- elif value is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value, value['required'] | default([])) -}}
}
{%- endif -%}
{%- if value['required'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- macro format_tool_response_block(tool_name, response) -%}
{{- '<|tool_response>' -}}
{%- if response is mapping -%}
{{- 'response:' + tool_name + '{' -}}
{%- for key, value in response | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{- bos_token -}}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>\n' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Pre-scan: find last user message index for reasoning guard -#}
{%- set ns_turn = namespace(last_user_idx=-1) -%}
{%- for i in range(loop_messages | length) -%}
{%- if loop_messages[i]['role'] == 'user' -%}
{%- set ns_turn.last_user_idx = i -%}
{%- endif -%}
{%- endfor -%}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- if message['role'] != 'tool' -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
{%- set prev_nt = namespace(role=None, found=false) -%}
{%- if loop.index0 > 0 -%}
{%- for j in range(loop.index0 - 1, -1, -1) -%}
{%- if not prev_nt.found -%}
{%- if loop_messages[j]['role'] != 'tool' -%}
{%- set prev_nt.role = loop_messages[j]['role'] -%}
{%- set prev_nt.found = true -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
{%- if not continue_same_model_turn -%}
{{- '<|turn>' + role + '\n' }}
{%- endif -%}
{#- Render reasoning/reasoning_content as thinking channel -#}
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- set ns_tr_out = namespace(flag=false) -%}
{%- if message.get('tool_responses') -%}
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
{%- for tool_response in message['tool_responses'] -%}
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endfor -%}
{%- elif message.get('tool_calls') -%}
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
{%- set ns_tool_scan = namespace(stopped=false) -%}
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
{%- if ns_tool_scan.stopped -%}
{%- elif loop_messages[k]['role'] != 'tool' -%}
{%- set ns_tool_scan.stopped = true -%}
{%- else -%}
{%- set follow = loop_messages[k] -%}
{#- Resolve tool_call_id to function name -#}
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
{%- for tc in message['tool_calls'] -%}
{%- if tc.get('id') == follow.get('tool_call_id') -%}
{%- set ns_tname.name = tc['function']['name'] -%}
{%- endif -%}
{%- endfor -%}
{#- Handle content as string or content-parts array -#}
{%- set tool_body = follow.get('content') -%}
{%- if tool_body is string -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- elif tool_body is sequence and tool_body is not string -%}
{%- set ns_txt = namespace(s='') -%}
{%- for part in tool_body -%}
{%- if part.get('type') == 'text' -%}
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
{%- endif -%}
{%- endfor -%}
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
{%- else -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- endif -%}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '<|image|>' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '<|video|>' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
{{- '<|tool_response>' -}}
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
{{- '<|turn>model\n' -}}
{%- if not enable_thinking | default(false) -%}
{{- '<|channel>thought\n<channel|>' -}}
{%- endif -%}
{%- endif -%}
{%- endif -%}
@@ -0,0 +1,74 @@
{
"audio_token": "<|audio|>",
"backend": "tokenizers",
"boa_token": "<|audio>",
"boi_token": "<|image>",
"bos_token": "<bos>",
"eoa_token": "<audio|>",
"eoc_token": "<channel|>",
"eoi_token": "<image|>",
"eos_token": "<eos>",
"eot_token": "<turn|>",
"escape_token": "<|\"|>",
"etc_token": "<tool_call|>",
"etd_token": "<tool|>",
"etr_token": "<tool_response|>",
"extra_special_tokens": [
"<|video|>"
],
"image_token": "<|image|>",
"mask_token": "<mask>",
"model_max_length": 1000000000000000019884624838656,
"pad_token": "<pad>",
"padding_side": "left",
"processor_class": "Gemma4Processor",
"response_schema": {
"type": "object",
"properties": {
"role": {
"const": "assistant"
},
"thinking": {
"type": "string"
},
"content": {
"type": "string"
},
"tool_calls": {
"x-regex-iterator": "<\\|tool_call>(.*?)<tool_call\\|>",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {
"const": "function"
},
"function": {
"type": "object",
"x-regex": "call\\:(?P<name>\\w+)(?P<arguments>\\{.*\\})",
"properties": {
"name": {
"type": "string"
},
"arguments": {
"type": "object",
"x-parser": "gemma4-tool-call",
"additionalProperties": {}
}
}
}
}
}
}
},
"x-regex": "(\\<\\|channel\\>thought\\n(?P<thinking>.*?)\\<channel\\|\\>)?(?P<tool_calls>\\<\\|tool_call\\>.*\\<tool_call\\|\\>)?(?P<content>(?:(?!\\<turn\\|\\>)(?!\\<\\|tool_response\\>).)+)?(?:\\<turn\\|\\>|\\<\\|tool_response\\>)?"
},
"soc_token": "<|channel>",
"sot_token": "<|turn>",
"stc_token": "<|tool_call>",
"std_token": "<|tool>",
"str_token": "<|tool_response>",
"think_token": "<|think|>",
"tokenizer_class": "GemmaTokenizer",
"unk_token": "<unk>"
}
@@ -0,0 +1,516 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: any-to-any
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning**: All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities**: Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures**: Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device**: Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window**: The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities**: Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support**: Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
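To make the `system` role and the configurable thinking mode concrete, the sketch below renders a text-only conversation into the turn format used by the bundled chat template (`<|turn>role\n ... <turn|>\n`, with `<|think|>` placed at the top of the system turn). It is a simplified illustration, not the template itself: it assumes plain-string content, no tools, no multimodal parts, and history that already has thoughts removed. `render_prompt` is a hypothetical name. In practice `apply_chat_template` should be preferred.

```python
BOS = "<bos>"


def render_prompt(messages, enable_thinking=False, add_generation_prompt=True):
    """Render text-only messages in the Gemma 4 turn layout (simplified)."""
    out = [BOS]
    msgs = list(messages)
    has_system = bool(msgs) and msgs[0]["role"] in ("system", "developer")

    # A system turn is emitted when there is a system message or thinking is on.
    if enable_thinking or has_system:
        out.append("<|turn>system\n")
        if enable_thinking:
            out.append("<|think|>\n")
        if has_system:
            out.append(msgs.pop(0)["content"].strip())
        out.append("<turn|>\n")

    for message in msgs:
        # The template maps the "assistant" role to "model".
        role = "model" if message["role"] == "assistant" else message["role"]
        out.append(f"<|turn>{role}\n{message['content'].strip()}<turn|>\n")

    if add_generation_prompt:
        out.append("<|turn>model\n")
        if not enable_thinking:
            # With thinking disabled, an empty thought block is pre-filled.
            out.append("<|channel>thought\n<channel|>")
    return "".join(out)


print(render_prompt(
    [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
))
```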
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
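The interleaving can be pictured with a toy layer schedule. The sketch below is illustrative only: the card states that local and global layers alternate and that the final layer is global, but it does not give the ratio, so `local_per_global=5` is purely an assumption, as is the convention that the sliding window includes the current token.

```python
def layer_types(n_layers, local_per_global=5):
    """Assign 'local' or 'global' to each layer; the last layer is always global."""
    types = []
    for i in range(n_layers):
        is_global = (i + 1) % (local_per_global + 1) == 0
        types.append("global" if is_global else "local")
    types[-1] = "global"  # the card guarantees a global final layer
    return types


def visible_positions(query_pos, layer_type, window):
    """Causal key positions that a query at `query_pos` can attend to."""
    if layer_type == "global":
        start = 0
    else:
        start = max(0, query_pos - window + 1)
    return range(start, query_pos + 1)


schedule = layer_types(30)
print(schedule.count("global"), schedule[-1])
print(len(visible_positions(2000, "local", 1024)))
print(len(visible_positions(2000, "global", 1024)))
```

The point of the design is visible in the last two lines: a local layer only ever looks at a fixed-size window, so its cost and cache size stay bounded, while the sparser global layers retain access to the full context.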
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only a subset of about 4B parameters (3.8B, per the table above) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
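A quick back-of-the-envelope calculation shows the trade-off between total and active parameters. The sketch below uses the table's figures (25.2B total, 3.8B active). The bytes-per-parameter values are generic assumptions, and the estimate covers weights only: it ignores the KV cache, activations, and any quantization overhead. It also assumes all experts stay resident in memory, which is the usual setup but not something this card states.

```python
BYTES_PER_PARAM = {"bf16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_gib(params_billions, dtype="bf16"):
    """Approximate weight storage in GiB for a given parameter count."""
    return params_billions * 1e9 * BYTES_PER_PARAM[dtype] / 2**30


TOTAL_B = 25.2   # total parameters, from the table
ACTIVE_B = 3.8   # active parameters, from the table

active_fraction = ACTIVE_B / TOTAL_B

print(f"bf16 weights: ~{weight_gib(TOTAL_B):.0f} GiB")
print(f"int4 weights: ~{weight_gib(TOTAL_B, 'int4'):.0f} GiB")
print(f"active share per token: ~{active_fraction:.0%}")
```

In other words, per-token compute scales with the roughly 15% of parameters that are active, while memory requirements still scale with the full 25.2B.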
## **Benchmark Results**
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking**: Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context**: Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding**: Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding**: Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input**: Freely mix text and images in any order within a single prompt.
* **Function Calling**: Native support for structured tool use, enabling agentic workflows.
* **Coding**: Code generation, completion, and correction.
* **Multilingual**: Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only): Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
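For the function-calling capability, tool calls appear in raw model output between `<|tool_call>` and `<tool_call|>` markers in the form `call:name{...}`. The sketch below extracts them using regular expressions adapted from the `response_schema` in the bundled `tokenizer_config.json`. It deliberately leaves the argument body as a raw string, because decoding it is delegated to a dedicated parser that is not reproduced here. `extract_tool_calls` is a hypothetical helper, and the sample output string is made up for illustration.

```python
import re

# Patterns adapted from the tokenizer config's response schema.
TOOL_CALL_RE = re.compile(r"<\|tool_call>(.*?)<tool_call\|>", re.DOTALL)
CALL_RE = re.compile(r"call:(?P<name>\w+)(?P<arguments>\{.*\})", re.DOTALL)


def extract_tool_calls(text):
    """Return the name and raw argument body of each tool call in `text`."""
    calls = []
    for body in TOOL_CALL_RE.findall(text):
        match = CALL_RE.match(body)
        if match:
            calls.append(
                {
                    "name": match.group("name"),
                    "raw_arguments": match.group("arguments"),
                }
            )
    return calls


# Illustrative output; string values are wrapped in the <|"|> escape token.
sample = '<|tool_call>call:get_weather{city:<|"|>Paris<|"|>,days:3}<tool_call|>'
print(extract_tool_calls(sample))
```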
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
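As a minimal sketch, the recommended values can be collected in one dictionary and unpacked into `generate`. The name `SAMPLING_KWARGS` is illustrative, and `do_sample=True` is an assumption added here: Transformers generally ignores `temperature`, `top_p`, and `top_k` unless sampling is enabled, and the checkpoint's own generation config may already set it.

```python
# Recommended sampling settings from the list above.
# `do_sample=True` is an assumption, not part of the recommendation itself.
SAMPLING_KWARGS = {
    "do_sample": True,
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}

# Usage with the `model` and `inputs` objects from Getting Started:
# outputs = model.generate(**inputs, max_new_tokens=1024, **SAMPLING_KWARGS)
```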
### 2. Thinking Mode Configuration
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
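If you work with raw decoded text instead of `processor.parse_response`, the thought block can be separated from the final answer with a small regular expression. This is only a sketch: `split_thought` is a hypothetical helper, it assumes exactly the `<|channel>thought\n ... <channel|>` structure described above, and it does not strip any end-of-turn tokens that may follow the answer.

```python
import re

# Matches the thought block described above: <|channel>thought\n ... <channel|>
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def split_thought(raw: str) -> tuple[str, str]:
    """Return (thought, answer) from raw decoded text (hypothetical helper)."""
    match = THOUGHT_RE.search(raw)
    if match is None:
        # No thought block at all, e.g. E2B/E4B with thinking disabled.
        return "", raw.strip()
    thought = match.group(1).strip()
    answer = raw[match.end():].strip()
    return thought, answer


thought, answer = split_thought("<|channel>thought\nAdd 2 and 2.<channel|>4")
print(thought)  # Add 2 and 2.
print(answer)   # 4
```

An empty thought block, as produced by the larger models when thinking is disabled, yields an empty `thought` string and the answer unchanged.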
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
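One way to enforce this rule is to clean assistant turns before appending the next user message. The sketch below assumes the history is a list of `{"role", "content"}` dictionaries with string content and that thoughts appear in the `<|channel>thought\n ... <channel|>` form; `strip_thoughts` is a hypothetical helper, not a library function.

```python
import re

# Thought block as it appears in raw model output.
THOUGHT_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(messages: list[dict]) -> list[dict]:
    """Return a copy of the history with thought blocks removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        if message["role"] == "assistant" and isinstance(content, str):
            content = THOUGHT_RE.sub("", content).strip()
        cleaned.append({**message, "content": content})
    return cleaned


history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<|channel>thought\nSimple sum.<channel|>4"},
]
history = strip_thoughts(history)
history.append({"role": "user", "content": "And doubled?"})
```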
### 4. Modality Order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
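For simple prompts with a single question about one or more attachments, the ordering can be applied mechanically. The helper below is a hypothetical sketch that moves every non-text part ahead of the text parts while keeping relative order within each group; it is not appropriate for deliberately interleaved prompts, where the position of each image carries meaning.

```python
def media_first(content: list[dict]) -> list[dict]:
    """Stable reorder: non-text parts (image, audio, video) first, text parts last."""
    media = [part for part in content if part.get("type") != "text"]
    text = [part for part in content if part.get("type") == "text"]
    return media + text


content = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
]
messages = [{"role": "user", "content": media_first(content)}]
```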
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
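To get a feel for the trade-off, the sketch below snaps a requested per-image token count to the nearest supported budget and estimates how many context tokens a set of images or video frames would consume. Both helpers are hypothetical; how the budget is actually passed to the processor is not shown in this card, so that step is left out.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def nearest_budget(requested: int) -> int:
    """Snap a requested per-image token count to the closest supported budget."""
    return min(SUPPORTED_BUDGETS, key=lambda budget: abs(budget - requested))


def image_token_cost(num_images: int, budget: int) -> int:
    """Approximate visual tokens used by `num_images` images at a given budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported budget {budget}; choose from {SUPPORTED_BUDGETS}")
    return num_images * budget


print(nearest_budget(300))         # 280
print(image_token_cost(60, 70))    # 4200 tokens for 60 video frames
print(image_token_cost(60, 1120))  # 67200 tokens for the same frames
```

The last two lines illustrate why low budgets suit video: at the highest budget, 60 frames alone would take up a large share of a 128K context window.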
### 6. Audio
Use the following prompt structures for audio processing:
* **Automatic Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
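The two prompt structures can be kept as format strings and filled in per request. This is a sketch under stated assumptions: the template constants and `audio_message` are hypothetical names, the exact line breaks inside the prompts are approximated, and the URL is a placeholder.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it "
    "into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', then the "
    "translation in {target}."
)


def audio_message(audio_url: str, prompt: str) -> dict:
    """Build a user turn with the audio part placed before the text part."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }


asr_prompt = ASR_TEMPLATE.format(language="English")
ast_prompt = AST_TEMPLATE.format(source="Spanish", target="English")
# Placeholder URL; substitute a real audio file of at most 30 seconds.
message = audio_message("https://example.com/clip.wav", ast_prompt)
```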
### 7. Audio and Video Length
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are sampled at one frame per second.
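These limits are easy to check before sending a request. The sketch below is illustrative only: the helper names are hypothetical, and splitting long audio into consecutive 30-second windows is one possible workaround, not a documented recommendation.

```python
import math

MAX_AUDIO_SECONDS = 30
MAX_VIDEO_SECONDS = 60
FRAMES_PER_SECOND = 1


def frames_for_video(duration_s: float) -> int:
    """Number of frames sampled at 1 fps, or raise if the clip is too long."""
    if duration_s > MAX_VIDEO_SECONDS:
        raise ValueError(f"video is {duration_s}s; the limit is {MAX_VIDEO_SECONDS}s")
    return max(1, math.ceil(duration_s * FRAMES_PER_SECOND))


def audio_chunks(duration_s: float) -> list[tuple[float, float]]:
    """Split a clip into consecutive (start, end) windows of at most 30 seconds."""
    duration_s = float(duration_s)
    chunks = []
    start = 0.0
    while start < duration_s:
        end = min(start + MAX_AUDIO_SECONDS, duration_s)
        chunks.append((start, end))
        start = end
    return chunks


print(frames_for_video(45))  # 45
print(audio_chunks(75))      # [(0.0, 30.0), (30.0, 60.0), (60.0, 75.0)]
```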
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning, symbolic representation, and to address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated as well as human evaluations were conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models are significantly safer than the Gemma 3 and 3n models while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
  * **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
  * **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
  * **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
  * **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
  * **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
  * **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
  * **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
  * **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
  * The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
  * The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
  * Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
  * A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
  * Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
  * Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
  * Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
  * VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
  * VLMs can be misused to generate text that is false, misleading, or harmful.
  * Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
  * This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
  * A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language model implementations, designed from the ground up for responsible AI development, that compare favorably with similarly sized models.
@@ -0,0 +1,513 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: any-to-any
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning** All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities** Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures** Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device** Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window** The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities** Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support** Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
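The difference between a local and a global layer can be illustrated with a per-position visibility rule. This is a conceptual sketch, not the Transformers implementation: `can_attend` is a hypothetical function, and the convention that a window of `w` tokens covers the query position plus the `w - 1` positions before it is an assumption.

```python
from typing import Optional


def can_attend(query_pos: int, key_pos: int, window: Optional[int]) -> bool:
    """Causal visibility rule; window=None models a global attention layer."""
    if key_pos > query_pos:
        return False  # causal: a token never attends to later positions
    if window is None:
        return True  # global layer: every earlier position is visible
    return query_pos - key_pos < window  # local layer: only the most recent tokens


# Count how many positions token 1000 can see in each kind of layer.
local_visible = sum(can_attend(1000, k, window=512) for k in range(1001))
global_visible = sum(can_attend(1000, k, window=None) for k in range(1001))
print(local_visible)   # 512
print(global_visible)  # 1001
```

The local count stays fixed as the sequence grows, which is where the memory savings of sliding-window layers come from, while the interleaved global layers keep the full context reachable.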
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
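The gap between the two parameter counts can be worked out from the table. The short calculation below treats the difference between the "with embeddings" and "effective" figures as embedding-table parameters, which is an interpretation of the table and not a number stated in this card.

```python
# Figures from the Dense Models table, in billions of parameters.
MODELS = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}

summary = {}
for name, params in MODELS.items():
    embedding = params["with_embeddings"] - params["effective"]
    share = embedding / params["with_embeddings"]
    summary[name] = f"~{embedding:.1f}B embedding parameters ({share:.0%} of total)"
    print(f"{name}: {summary[name]}")
# E2B: ~2.8B embedding parameters (55% of total)
# E4B: ~3.5B embedding parameters (44% of total)
```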
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
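A quick calculation with the figures from the two tables shows the scale of the difference. This is back-of-the-envelope arithmetic only: per-token compute is assumed to scale roughly with active parameters, and all 25.2B weights still have to be stored, so the memory footprint does not shrink in the same way.

```python
# Parameter counts from the tables above, in billions.
MOE_TOTAL_B = 25.2
MOE_ACTIVE_B = 3.8
DENSE_31B_B = 30.7

active_share = MOE_ACTIVE_B / MOE_TOTAL_B    # fraction of MoE weights used per token
dense_vs_active = DENSE_31B_B / MOE_ACTIVE_B  # dense 31B weights per active MoE weight

print(f"Active share of 26B A4B: {active_share:.1%}")              # 15.1%
print(f"31B dense vs. active MoE parameters: {dense_vs_active:.1f}x")  # 8.1x
```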
## **Benchmark Results**
These models were evaluated against a large collection of datasets and metrics covering different aspects of text generation. The results in the table are for the instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking** Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context** Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding** Object detection, Document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding** Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input** Freely mix text and images in any order within a single prompt.
* **Function Calling** Native support for structured tool use, enabling agentic workflows.
* **Coding** Code generation, completion, and correction.
* **Multilingual** Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only) Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]

# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM

MODEL_ID = "google/gemma-4-E2B-it"

# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]

# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)

# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
### 2. Thinking Mode Configuration
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
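If you build messages by hand rather than relying on a library flag, the trigger described above can be sketched as follows. `build_messages` is a hypothetical helper; whether a separator should follow the `<|think|>` token is not specified in this card, so none is added here.

```python
THINK_TOKEN = "<|think|>"  # control token named in this card


def build_messages(user_text, system_text="", enable_thinking=False):
    """Build a chat message list, toggling thinking via the system prompt."""
    # Thinking is enabled by placing the token at the start of the system prompt.
    # Assumption: no separator is inserted between the token and the prompt text.
    system = (THINK_TOKEN + system_text) if enable_thinking else system_text
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user_text})
    return messages


print(build_messages("Why is the sky blue?", "You are a helpful assistant.", True))
```

With `enable_thinking=False` and an empty system prompt, no system message is emitted at all, which matches the "remove the token" instruction for disabling thinking.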
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
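One way to enforce this rule is to clean assistant turns before sending the history back to the model. The sketch below assumes thought blocks appear exactly in the `<|channel>thought\n...<channel|>` layout shown in this card; `strip_thoughts` is a hypothetical helper, not a library function.

```python
import re

# Assumption: thought blocks follow the layout documented in this card.
THOUGHT_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(messages):
    """Return a copy of the history with thought blocks removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        # Only plain-string assistant turns are rewritten; other turns pass through.
        if message["role"] == "assistant" and isinstance(content, str):
            content = THOUGHT_RE.sub("", content).strip()
        cleaned.append({**message, "content": content})
    return cleaned


history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<|channel>thought\nAdd the numbers.<channel|>4"},
]
print(strip_thoughts(history))
```

The function returns new dictionaries, so the original history (with thoughts) remains available for logging.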
### 4. Modality order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
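A small helper can apply this ordering to a message's content list. This is a sketch under assumptions: `media_first` is a hypothetical name, and treating `video` parts like image and audio parts follows the "add video before text" comment in the video example rather than an explicit rule.

```python
# Assumption: "video" is grouped with image and audio parts.
MEDIA_TYPES = {"image", "audio", "video"}


def media_first(content):
    """Reorder content parts so media comes before text, preserving relative order."""
    # sorted() is stable, so parts within each group keep their original order.
    # The key is False (sorts first) for media parts and True for everything else.
    return sorted(content, key=lambda part: part["type"] not in MEDIA_TYPES)


parts = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image", "url": "example.png"},
]
print(media_first(parts))
```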
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
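Because only five budgets are supported, an arbitrary target has to be mapped onto one of them. The sketch below rounds down to the nearest supported value; `snap_budget` is a hypothetical helper, and how the chosen budget is actually passed to the processor is not covered in this card, so that step is left out.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def snap_budget(requested: int) -> int:
    """Return the largest supported budget not exceeding `requested`.

    Requests below the smallest budget fall back to the smallest budget.
    """
    eligible = [b for b in SUPPORTED_BUDGETS if b <= requested]
    return max(eligible) if eligible else SUPPORTED_BUDGETS[0]


for target in (50, 300, 560, 5000):
    print(target, "->", snap_budget(target))
```

Rounding down is a design choice that keeps compute within the requested ceiling; rounding up would favor detail instead.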
### 6. Audio
Use the following prompt structures for audio processing:
* **Automatic Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
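The two templates above can be filled in programmatically. In this sketch the template text is copied from this card, with single newlines between lines as an assumption about the original layout; `asr_prompt` and `ast_prompt` are hypothetical helper names.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not one "
    "point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate "
    "it into {TARGET_LANGUAGE}.\n"
    "When formatting the answer, first output the transcription in "
    "{SOURCE_LANGUAGE}, then one newline, then output the string "
    "'{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}."
)


def asr_prompt(language: str) -> str:
    """Fill the speech recognition template for one language."""
    return ASR_TEMPLATE.format(LANGUAGE=language)


def ast_prompt(source_language: str, target_language: str) -> str:
    """Fill the speech translation template for a language pair."""
    return AST_TEMPLATE.format(
        SOURCE_LANGUAGE=source_language, TARGET_LANGUAGE=target_language
    )


print(ast_prompt("English", "French"))
```

The resulting string goes in the `text` part of the user message, after the audio part.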
### 7. Audio and Video Length
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio input. Audio clips can be at most 30 seconds long. Video can be at most 60 seconds long, assuming frames are sampled at one frame per second.
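For recordings longer than the audio limit, one option is to process them in consecutive windows. This is a sketch, not a recommended pipeline: `audio_windows` is a hypothetical helper, and cutting at fixed offsets can split words, which a real system would need to handle (for example with overlap or silence detection).

```python
def audio_windows(duration_s: float, max_len_s: float = 30.0):
    """Yield (start, end) offsets in seconds covering a clip in windows of at most max_len_s."""
    start = 0.0
    while start < duration_s:
        end = min(start + max_len_s, duration_s)
        yield (start, end)
        start = end


# A 75-second recording becomes three windows, the last one shorter.
print(list(audio_windows(75)))
```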
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated and human evaluations was conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the models' capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. Its purpose is to provide context about the use cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
* The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
* VLMs can be misused to generate text that is false, misleading, or harmful.
* Guidelines for responsible use of the model are provided in the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Developers are encouraged to perform continuous monitoring (using evaluation metrics and human review) and to explore de-biasing techniques during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language models that, compared with similarly sized models, were designed from the ground up for responsible AI development.
@@ -0,0 +1,513 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: any-to-any
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning**: All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities**: Processes text, images with variable aspect ratio and resolution (all models), video, and audio (natively on the E2B and E4B models).
* **Diverse & Efficient Architectures**: Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device**: Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window**: The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities**: Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support**: Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
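To make the difference between the two layer types concrete, the sketch below builds a causal attention mask with and without a sliding window. It is a toy illustration in plain Python: `attention_mask` is a hypothetical helper, the window here counts the current token (an assumption about the convention), and the ratio of local to global layers is not stated in this card, so it is not modeled.

```python
def attention_mask(seq_len: int, window=None):
    """Build a causal mask; mask[q][k] is True if query q may attend to key k.

    window=None gives full (global) causal attention. An integer limits each
    query to the most recent `window` positions, including itself.
    """
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            causal = k <= q
            in_window = window is None or q - k < window
            row.append(causal and in_window)
        mask.append(row)
    return mask


# With 6 tokens and a window of 3, the last query sees only the last 3 keys,
# while a global layer sees all 6.
local_mask = attention_mask(6, window=3)
global_mask = attention_mask(6)
print(sum(local_mask[-1]), sum(global_mask[-1]))
```

In a local layer the number of keys per query is bounded by the window size (512 or 1024 tokens in the tables below), while in a global layer it grows with the sequence length; this is the source of the memory savings described above.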
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
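The gap between the two parameter counts in the table can be worked out directly. The sketch below attributes the whole difference between "with embeddings" and "effective" to embedding tables, which is an interpretation of the table rather than a published breakdown; `embedding_share` is a hypothetical helper.

```python
# (effective, total with embeddings) in billions, taken from the table above.
PARAMS_B = {"E2B": (2.3, 5.1), "E4B": (4.5, 8.0)}


def embedding_share(effective_b: float, total_b: float):
    """Return (embedding parameters in billions, their share of the total)."""
    embedding_b = total_b - effective_b
    return embedding_b, embedding_b / total_b


for name, (effective_b, total_b) in PARAMS_B.items():
    embedding_b, share = embedding_share(effective_b, total_b)
    print(f"{name}: {embedding_b:.1f}B in embeddings ({share:.0%} of total)")
```

Under this reading, embedding lookups account for roughly half of the stored parameters in both small models.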
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters", in contrast to the total number of parameters the model contains. By activating only a subset of roughly 4B parameters (3.8B) during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model, since it runs almost as fast as a 4B-parameter model.
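The idea of activating only a few experts per token can be sketched with a generic top-k router. This is an illustration of the general technique, not Gemma 4's implementation: how the model scores experts, normalizes their weights, and combines the shared expert is not described in this card, and `route` is a hypothetical helper.

```python
import math

# From the table above: 3.8B of 25.2B parameters are active per token.
ACTIVE_FRACTION = 3.8 / 25.2


def route(scores, top_k=8):
    """Pick the top_k experts by router score and softmax-normalize their weights."""
    chosen = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    # Subtract the maximum score before exponentiating for numerical stability.
    peak = max(scores[i] for i in chosen)
    exps = {i: math.exp(scores[i] - peak) for i in chosen}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}


# Toy router scores for 128 experts; only 8 receive a non-zero weight.
toy_scores = [0.1 * i for i in range(128)]
weights = route(toy_scores)
print(sorted(weights), f"active fraction: {ACTIVE_FRACTION:.1%}")
```

Only the selected experts run for a given token, so per-token compute tracks the active parameter count (about 15% of the total) rather than the full model size.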
## **Benchmark Results**
These models were evaluated against a large collection of datasets and metrics covering different aspects of text generation. The results in the table are for instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking**: Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context**: Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding**: Object detection, document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding**: Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input**: Freely mix text and images in any order within a single prompt.
* **Function Calling**: Native support for structured tool use, enabling agentic workflows.
* **Coding**: Code generation, completion, and correction.
* **Multilingual**: Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only): Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
dtype="auto",
device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a short joke about saving RAM."},
]
# Process input
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
MODEL_ID,
dtype="auto",
device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
{
"role": "user",
"content": [
{"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
{"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
]
}
]
# Process input
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
MODEL_ID,
dtype="auto",
device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
{
"role": "user", "content": [
{"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
{"type": "text", "text": "What is shown in this image?"}
]
}
]
# Process input
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
MODEL_ID,
dtype="auto",
device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."},
        ],
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
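To illustrate what `top_k` and `top_p` do to a next-token distribution, here is a toy filter in plain Python. It shows the general technique only: real inference libraries apply these filters to logits and may differ in tie handling, and `top_k_top_p_filter` is a hypothetical helper.

```python
def top_k_top_p_filter(probs, top_k=64, top_p=0.95):
    """Return {token_index: renormalized_prob} after top-k, then top-p, filtering."""
    # Rank tokens from most to least likely and keep at most top_k of them.
    ranked = sorted(enumerate(probs), key=lambda item: item[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for index, p in ranked:
        kept.append((index, p))
        cumulative += p
        # Stop once the kept tokens cover at least top_p of the probability mass.
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {index: p / total for index, p in kept}


# Toy distribution over five tokens; the low-probability tail is dropped.
toy = [0.6, 0.25, 0.1, 0.04, 0.01]
print(top_k_top_p_filter(toy, top_k=64, top_p=0.9))
```

With `temperature=1.0` the distribution is left unscaled before filtering, so the two filters above are the only truncation applied under the recommended settings.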
### 2. Thinking Mode Configuration
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Many libraries, such as Transformers and llama.cpp, handle the complexities of the chat template for you.
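When working with raw model output outside such libraries, the structure above can be split with a regular expression. This is a hand-rolled sketch, not a replacement for `processor.parse_response`: it assumes the thought block appears exactly as documented in this card and does not strip any other special tokens. `split_thinking` is a hypothetical helper.

```python
import re

# Matches one thought block in the layout documented in this card:
# <|channel>thought\n ...reasoning... <channel|>
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def split_thinking(response: str):
    """Return (reasoning, final_answer); reasoning is '' if absent or empty."""
    match = THOUGHT_RE.search(response)
    if match is None:
        return "", response.strip()
    reasoning = match.group(1).strip()
    answer = response[match.end():].strip()
    return reasoning, answer


print(split_thinking("<|channel>thought\nAdd the numbers.<channel|>The answer is 4."))
print(split_thinking("<|channel>thought\n<channel|>Hello!"))
```

The second call covers the empty thought block that the larger models emit when thinking is disabled.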
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
### 4. Modality order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
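The trade-off between budget and frame count is simple arithmetic. The sketch below treats the budget as the per-image token cost, which is an upper-bound assumption since the card does not say whether every image consumes its full budget; `visual_tokens` is a hypothetical helper.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def visual_tokens(num_images: int, budget: int) -> int:
    """Total visual tokens for `num_images` images at a given per-image budget."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported budget {budget}; choose from {SUPPORTED_BUDGETS}")
    return num_images * budget


# 60 frames, e.g. a 60-second video sampled at one frame per second.
for budget in SUPPORTED_BUDGETS:
    print(f"budget {budget:>4}: {visual_tokens(60, budget):>6} tokens")
```

Under this assumption, moving 60 frames from the lowest to the highest budget raises the visual token count sixteenfold, which is why lower budgets are suggested for video.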
### 6. Audio
Use the following prompt structures for audio processing:
* **Automatic Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
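Because the AST prompt fixes the output layout (transcription, one newline, then `'{TARGET_LANGUAGE}: '` and the translation), the response can be split mechanically. This sketch assumes the model followed the requested format; `parse_ast_output` is a hypothetical helper and raises an error when the marker is missing.

```python
def parse_ast_output(output: str, target_language: str):
    """Split an AST response into (transcription, translation)."""
    # The prompt asks for: transcription, newline, "<TARGET_LANGUAGE>: ", translation.
    marker = f"\n{target_language}: "
    transcription, found, translation = output.partition(marker)
    if not found:
        raise ValueError(f"expected a line starting with '{target_language}: '")
    return transcription.strip(), translation.strip()


print(parse_ast_output("Hello world\nFrench: Bonjour le monde", "French"))
```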
### 7. Audio and Video Length
All models support image inputs and can process video as a sequence of frames; the E2B and E4B models also support audio input. Audio clips can be at most 30 seconds long. Video can be at most 60 seconds long, assuming frames are sampled at one frame per second.
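The video limit amounts to a cap of 60 frames (60 seconds at one frame per second). The sketch below picks frame timestamps under that cap; spreading frames evenly across a longer video instead of truncating it is a design choice made here, not guidance from this card, and `frame_timestamps` is a hypothetical helper.

```python
def frame_timestamps(duration_s: float, fps: float = 1.0, max_frames: int = 60):
    """Return the timestamps (in seconds) of frames to sample, capped at max_frames."""
    if duration_s <= 0:
        return []
    # Sample at `fps`, but never more than max_frames and never fewer than one frame.
    count = max(min(int(duration_s * fps), max_frames), 1)
    step = duration_s / count
    return [round(i * step, 3) for i in range(count)]


print(len(frame_timestamps(10)), len(frame_timestamps(120)))
```

For a 10-second clip this yields one frame per second; for a 120-second clip it yields 60 frames spaced two seconds apart.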
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning and symbolic representation, and helps it address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated and human evaluations was conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the models' capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the models produced minimal policy violations and showed significant improvements over previous Gemma models.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. Its purpose is to provide context about the use cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
* The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
* VLMs can be misused to generate text that is false, misleading, or harmful.
* Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Continuous monitoring (using evaluation metrics and human review) and the exploration of de-biasing techniques are encouraged during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language model implementations that, compared to similarly sized models, are designed from the ground up for responsible AI development.
@@ -0,0 +1,515 @@
---
library_name: transformers
license: apache-2.0
license_link: https://ai.google.dev/gemma/docs/gemma_4_license
pipeline_tag: any-to-any
---
<div align="center">
<img src=https://ai.google.dev/gemma/images/gemma4_banner.png>
</div>
<p align="center">
<a href="https://huggingface.co/collections/google/gemma-4" target="_blank">Hugging Face</a> |
<a href="https://github.com/google-gemma" target="_blank">GitHub</a> |
<a href="https://blog.google/innovation-and-ai/technology/developers-tools/gemma-4/" target="_blank">Launch Blog</a> |
<a href="https://ai.google.dev/gemma/docs/core" target="_blank">Documentation</a>
<br>
<b>License</b>: <a href="https://ai.google.dev/gemma/docs/gemma_4_license" target="_blank">Apache 2.0</a> | <b>Authors</b>: <a href="https://deepmind.google/models/gemma/" target="_blank">Google DeepMind</a>
</p>
Gemma is a family of open models built by Google DeepMind. Gemma 4 models are multimodal, handling text and image input (with audio supported on small models) and generating text output. This release includes open-weights models in both pre-trained and instruction-tuned variants. Gemma 4 features a context window of up to 256K tokens and maintains multilingual support in over 140 languages.
Featuring both Dense and Mixture-of-Experts (MoE) architectures, Gemma 4 is well-suited for tasks like text generation, coding, and reasoning. The models are available in four distinct sizes: **E2B**, **E4B**, **26B A4B**, and **31B**. Their diverse sizes make them deployable in environments ranging from high-end phones to laptops and servers, democratizing access to state-of-the-art AI.
Gemma 4 introduces key **capability and architectural advancements**:
* **Reasoning** All models in the family are designed as highly capable reasoners, with configurable thinking modes.
* **Extended Multimodalities** Processes Text, Image with variable aspect ratio and resolution support (all models), Video, and Audio (featured natively on the E2B and E4B models).
* **Diverse & Efficient Architectures** Offers Dense and Mixture-of-Experts (MoE) variants of different sizes for scalable deployment.
* **Optimized for On-Device** Smaller models are specifically designed for efficient local execution on laptops and mobile devices.
* **Increased Context Window** The small models feature a 128K context window, while the medium models support 256K.
* **Enhanced Coding & Agentic Capabilities** Achieves notable improvements in coding benchmarks alongside native function-calling support, powering highly capable autonomous agents.
* **Native System Prompt Support** Gemma 4 introduces native support for the `system` role, enabling more structured and controllable conversations.
## **Models Overview**
Gemma 4 models are designed to deliver frontier-level performance at each size, targeting deployment scenarios from mobile and edge devices (E2B, E4B) to consumer GPUs and workstations (26B A4B, 31B). They are well-suited for reasoning, agentic workflows, coding, and multimodal understanding.
The models employ a hybrid attention mechanism that interleaves local sliding window attention with full global attention, ensuring the final layer is always global. This hybrid design delivers the processing speed and low memory footprint of a lightweight model without sacrificing the deep awareness required for complex, long-context tasks. To optimize memory for long contexts, global layers feature unified Keys and Values, and apply Proportional RoPE (p-RoPE).
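The difference between the two layer types can be sketched with plain attention masks. This is a minimal illustration, not the model's implementation: the helper names are hypothetical, the window is assumed to include the current token, and the real interleaving pattern of local and global layers is not reproduced here.

```python
def causal_mask(seq_len: int) -> list[list[bool]]:
    """Global causal mask: position i may attend to every position j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Local causal mask: position i attends only to the last `window`
    positions, itself included (an assumption made for this sketch)."""
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Tiny sizes for readability; the cards list windows of 512 or 1024 tokens.
local_mask = sliding_window_mask(seq_len=6, window=3)
global_mask = causal_mask(seq_len=6)

# Number of positions visible to the last token under each mask.
local_visible = sum(local_mask[-1])
global_visible = sum(global_mask[-1])
print(local_visible, global_visible)
```

A local layer keeps the per-token attention cost bounded by the window size, while a global layer grows with the full sequence length, which is why mixing the two reduces memory use on long contexts.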
### Dense Models
| Property | E2B | E4B | 31B Dense |
| :---- | :---- | :---- | :---- |
| **Total Parameters** | 2.3B effective (5.1B with embeddings) | 4.5B effective (8B with embeddings) | 30.7B |
| **Layers** | 35 | 42 | 60 |
| **Sliding Window** | 512 tokens | 512 tokens | 1024 tokens |
| **Context Length** | 128K tokens | 128K tokens | 256K tokens |
| **Vocabulary Size** | 262K | 262K | 262K |
| **Supported Modalities** | Text, Image, Audio | Text, Image, Audio | Text, Image |
| **Vision Encoder Parameters** | *~150M* | *~150M* | *~550M* |
| **Audio Encoder Parameters** | *~300M* | *~300M* | No Audio |
The "E" in E2B and E4B stands for "effective" parameters. The smaller models incorporate Per-Layer Embeddings (PLE) to maximize parameter efficiency in on-device deployments. Rather than adding more layers or parameters to the model, PLE gives each decoder layer its own small embedding for every token. These embedding tables are large but are only used for quick lookups, which is why the effective parameter count is much smaller than the total.
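As a rough check on the figures above, the gap between the two parameter counts in the table gives the size of the embedding tables. The sketch below only subtracts the rounded numbers from the table, so the results are approximate; the dictionary and function names are illustrative.

```python
# Parameter counts in billions, copied from the Dense Models table.
PARAMS_B = {
    "E2B": {"effective": 2.3, "with_embeddings": 5.1},
    "E4B": {"effective": 4.5, "with_embeddings": 8.0},
}


def embedding_params_b(name: str) -> float:
    """Parameters (in billions) attributed to embeddings, per the table."""
    entry = PARAMS_B[name]
    return round(entry["with_embeddings"] - entry["effective"], 1)


for name in PARAMS_B:
    print(name, embedding_params_b(name))
```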
### Mixture-of-Experts (MoE) Model
| Property | 26B A4B MoE |
| :---- | :---- |
| **Total Parameters** | 25.2B |
| **Active Parameters** | 3.8B |
| **Layers** | 30 |
| **Sliding Window** | 1024 tokens |
| **Context Length** | 256K tokens |
| **Vocabulary Size** | 262K |
| **Expert Count** | 8 active / 128 total and 1 shared |
| **Supported Modalities** | Text, Image |
| **Vision Encoder Parameters** | *~550M* |
The "A" in 26B A4B stands for "active parameters" in contrast to the total number of parameters the model contains. By only activating a 4B subset of parameters during inference, the Mixture-of-Experts model runs much faster than its 26B total might suggest. This makes it an excellent choice for fast inference compared to the dense 31B model since it runs almost as fast as a 4B-parameter model.
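The ratio behind that claim can be worked out from the table. This is simple arithmetic on the rounded figures above; the variable names are illustrative, and the shared expert is left out of the routed-expert ratio.

```python
# Figures copied from the Mixture-of-Experts table.
TOTAL_PARAMS_B = 25.2
ACTIVE_PARAMS_B = 3.8
ACTIVE_EXPERTS = 8
TOTAL_EXPERTS = 128

# Share of all parameters that take part in a single forward pass.
active_param_fraction = ACTIVE_PARAMS_B / TOTAL_PARAMS_B

# Share of routed experts selected per token (shared expert not counted).
routed_expert_fraction = ACTIVE_EXPERTS / TOTAL_EXPERTS

print(f"{active_param_fraction:.1%} of parameters active")
print(f"{routed_expert_fraction:.2%} of routed experts active")
```

Roughly 15% of the parameters are used per token, although all 25.2B still have to be held in memory.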
## **Benchmark Results**
These models were evaluated against a large collection of different datasets and metrics to cover different aspects of text generation. Evaluation results marked in the table are for instruction-tuned models.
| | Gemma 4 31B | Gemma 4 26B A4B | Gemma 4 E4B | Gemma 4 E2B | Gemma 3 27B (no think) |
| :---- | :---- | :---- | :---- | :---- | :---- |
| MMLU Pro | 85.2% | 82.6% | 69.4% | 60.0% | 67.6% |
| AIME 2026 no tools | 89.2% | 88.3% | 42.5% | 37.5% | 20.8% |
| LiveCodeBench v6 | 80.0% | 77.1% | 52.0% | 44.0% | 29.1% |
| Codeforces ELO | 2150 | 1718 | 940 | 633 | 110 |
| GPQA Diamond | 84.3% | 82.3% | 58.6% | 43.4% | 42.4% |
| Tau2 (average over 3) | 76.9% | 68.2% | 42.2% | 24.5% | 16.2% |
| HLE no tools | 19.5% | 8.7% | - | - | - |
| HLE with search | 26.5% | 17.2% | - | - | - |
| BigBench Extra Hard | 74.4% | 64.8% | 33.1% | 21.9% | 19.3% |
| MMMLU | 88.4% | 86.3% | 76.6% | 67.4% | 70.7% |
| **Vision** | | | | | |
| MMMU Pro | 76.9% | 73.8% | 52.6% | 44.2% | 49.7% |
| OmniDocBench 1.5 (average edit distance, lower is better) | 0.131 | 0.149 | 0.181 | 0.290 | 0.365 |
| MATH-Vision | 85.6% | 82.4% | 59.5% | 52.4% | 46.0% |
| MedXPertQA MM | 61.3% | 58.1% | 28.7% | 23.5% | - |
| **Audio** | | | | | |
| CoVoST | - | - | 35.54 | 33.47 | - |
| FLEURS (lower is better) | - | - | 0.08 | 0.09 | - |
| **Long Context** | | | | | |
| MRCR v2 8 needle 128k (average) | 66.4% | 44.1% | 25.4% | 19.1% | 13.5% |
## **Core Capabilities**
Gemma 4 models handle a broad range of tasks across text, vision, and audio. Key capabilities include:
* **Thinking** Built-in reasoning mode that lets the model think step-by-step before answering.
* **Long Context** Context windows of up to 128K tokens (E2B/E4B) and 256K tokens (26B A4B/31B).
* **Image Understanding** Object detection, document/PDF parsing, screen and UI understanding, chart comprehension, OCR (including multilingual), handwriting recognition, and pointing. Images can be processed at variable aspect ratios and resolutions.
* **Video Understanding** Analyze video by processing sequences of frames.
* **Interleaved Multimodal Input** Freely mix text and images in any order within a single prompt.
* **Function Calling** Native support for structured tool use, enabling agentic workflows.
* **Coding** Code generation, completion, and correction.
* **Multilingual** Out-of-the-box support for 35+ languages, pre-trained on 140+ languages.
* **Audio** (E2B and E4B only) Automatic speech recognition (ASR) and speech-to-translated-text translation across multiple languages.
## Getting Started
You can use all Gemma 4 models with the latest version of Transformers. To get started, install the necessary dependencies in your environment:
`pip install -U transformers torch accelerate`
Once you have everything installed, you can proceed to load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForCausalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output:
```python
# Prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short joke about saving RAM."},
]
# Process input
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
To enable reasoning, set `enable_thinking=True` and the `parse_response` function will take care of parsing the thinking output.
Below, you will also find snippets for processing audio (E2B and E4B only), images, and video alongside text:
<details>
<summary>Code for processing Audio</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process audio. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the audio URL in the prompt:
```python
# Prompt - add audio before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"},
            {"type": "text", "text": "Transcribe the following speech segment in its original language. Follow these specific instructions for formatting the answer:\n* Only output the transcription, with no newlines.\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Images</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process images. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the image URL in the prompt:
```python
# Prompt - add image before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"},
            {"type": "text", "text": "What is shown in this image?"}
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
<details>
<summary>Code for processing Videos</summary>
Instead of using `AutoModelForCausalLM`, you can use `AutoModelForMultimodalLM` to process videos. To use it, make sure to install the following packages:
`pip install -U transformers torch torchvision librosa accelerate`
You can then load the model with the code below:
```python
from transformers import AutoProcessor, AutoModelForMultimodalLM
MODEL_ID = "google/gemma-4-E4B-it"
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto"
)
```
Once the model is loaded, you can start generating output by directly referencing the video URL in the prompt:
```python
# Prompt - add video before text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"},
            {"type": "text", "text": "Describe this video."}
        ]
    }
]
# Process input
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
# Generate output
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
# Parse output
processor.parse_response(response)
```
</details>
## **Best Practices**
For the best performance, use these configurations and best practices:
### 1. Sampling Parameters
Use the following standardized sampling configuration across all use cases:
* `temperature=1.0`
* `top_p=0.95`
* `top_k=64`
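
These values can be collected once and reused for every call. The sketch below is one way to do that with Transformers-style keyword arguments; `do_sample=True` is an assumption added here so that the sampling parameters take effect, and the name `SAMPLING_KWARGS` is illustrative. The checkpoint's own generation config may already carry the same defaults.

```python
# Recommended sampling configuration from this card.
# `do_sample=True` is not listed in the card; it is assumed here because
# temperature/top_p/top_k are only applied when sampling is enabled.
SAMPLING_KWARGS = {
    "do_sample": True,
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 64,
}

# Usage with an already loaded model and processed inputs:
# outputs = model.generate(**inputs, max_new_tokens=1024, **SAMPLING_KWARGS)
print(SAMPLING_KWARGS)
```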
### 2. Thinking Mode Configuration
Unlike Gemma 3, Gemma 4 models use the standard `system`, `assistant`, and `user` roles. To properly manage the thinking process, use the following control tokens:
* **Trigger Thinking:** Thinking is enabled by including the `<|think|>` token at the start of the system prompt. To disable thinking, remove the token.
* **Standard Generation:** When thinking is enabled, the model will output its internal reasoning followed by the final answer using this structure:
`<|channel>thought\n`**[Internal reasoning]**`<channel|>`
* **Disabled Thinking Behavior:** For all models except for the E2B and E4B variants, if thinking is disabled, the model will still generate the tags but with an empty thought block:
`<|channel>thought\n<channel|>`**[Final answer]**
> [!Note]
> Note that many libraries like Transformers and llama.cpp handle the complexities of the chat template for you.
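
For cases where the raw text is handled by hand, the control tokens above can be processed with a few lines of string handling. This is a minimal sketch: `build_system_prompt` and `split_thought` are hypothetical helpers, placing the token directly in front of the instructions with no separator is an assumption, and any turn-delimiter tokens in the raw output are left untouched.

```python
import re

THINK_TOKEN = "<|think|>"
# Matches "<|channel>thought\n[reasoning]<channel|>"; the reasoning may be empty.
THOUGHT_RE = re.compile(r"<\|channel>thought\n(.*?)<channel\|>", re.DOTALL)


def build_system_prompt(instructions: str, thinking: bool) -> str:
    """Put the thinking trigger at the start of the system prompt if requested."""
    return f"{THINK_TOKEN}{instructions}" if thinking else instructions


def split_thought(raw: str) -> tuple[str, str]:
    """Return (reasoning, final_answer) from raw decoded model output."""
    match = THOUGHT_RE.search(raw)
    if match is None:
        # No thought block at all: everything is the answer.
        return "", raw.strip()
    return match.group(1).strip(), raw[match.end():].strip()


raw_output = "<|channel>thought\nRAM is finite.<channel|>Why did the byte diet? To save RAM."
reasoning, answer = split_thought(raw_output)
print(reasoning)
print(answer)
```

The same function also covers the empty thought block that the larger models emit when thinking is disabled, returning an empty reasoning string.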
### 3. Multi-Turn Conversations
* **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final response. Thoughts from previous model turns must *not be added* before the next user turn begins.
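
One way to enforce this rule is to clean the message list before each new turn. The helper below is a hypothetical sketch that assumes assistant turns are stored as plain strings which may still contain the raw `<|channel>thought\n...<channel|>` block; if the parsed final answer is stored instead, no cleanup is needed.

```python
import re

# Thought block as described in the thinking-mode section of this card.
THOUGHT_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)


def strip_thoughts(messages: list[dict]) -> list[dict]:
    """Return a copy of the history with thought blocks removed from assistant turns."""
    cleaned = []
    for message in messages:
        content = message["content"]
        if message["role"] == "assistant" and isinstance(content, str):
            content = THOUGHT_RE.sub("", content).strip()
        cleaned.append({**message, "content": content})
    return cleaned


history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<|channel>thought\nSimple sum.<channel|>4"},
]
next_turn = strip_thoughts(history) + [{"role": "user", "content": "And 3 + 3?"}]
print(next_turn)
```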
### 4. Modality order
* For optimal performance with multimodal inputs, place image and/or audio content **before** the text in your prompt.
### 5. Variable Image Resolution
Aside from variable aspect ratios, Gemma 4 supports variable image resolution through a configurable visual token budget, which controls how many tokens are used to represent an image. A higher token budget preserves more visual detail at the cost of additional compute, while a lower budget enables faster inference for tasks that don't require fine-grained understanding.
* The supported token budgets are: **70**, **140**, **280**, **560**, and **1120**.
* Use *lower budgets* for classification, captioning, or video understanding, where faster inference and processing many frames outweigh fine-grained detail.
* Use *higher budgets* for tasks like OCR, document parsing, or reading small text.
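
The trade-off can be made concrete with a small budgeting helper. This is a sketch under assumptions: the function names are hypothetical, the budget is treated as the number of tokens spent per image, and how the budget is actually passed to the processor is not shown because it is not described here.

```python
SUPPORTED_BUDGETS = (70, 140, 280, 560, 1120)


def pick_budget(requested: int) -> int:
    """Smallest supported budget that covers `requested`, capped at the largest."""
    for budget in SUPPORTED_BUDGETS:
        if budget >= requested:
            return budget
    return SUPPORTED_BUDGETS[-1]


def total_visual_tokens(num_images: int, budget: int) -> int:
    """Estimated tokens spent on visual input for a prompt."""
    if budget not in SUPPORTED_BUDGETS:
        raise ValueError(f"unsupported token budget: {budget}")
    return num_images * budget


# 60 video frames at the lowest budget versus the highest budget.
low_cost = total_visual_tokens(60, 70)
high_cost = total_visual_tokens(60, 1120)
print(pick_budget(300), low_cost, high_cost)
```

At the highest budget, a one-minute video sampled at one frame per second would consume a large share of a 128K context window, which is why the lower budgets are suggested for video.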
### 6. Audio
Use the following prompt structures for audio processing:
* **Audio Speech Recognition (ASR)**
```text
Transcribe the following speech segment in {LANGUAGE} into {LANGUAGE} text.
Follow these specific instructions for formatting the answer:
* Only output the transcription, with no newlines.
* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three.
```
* **Automatic Speech Translation (AST)**
```text
Transcribe the following speech segment in {SOURCE_LANGUAGE}, then translate it into {TARGET_LANGUAGE}.
When formatting the answer, first output the transcription in {SOURCE_LANGUAGE}, then one newline, then output the string '{TARGET_LANGUAGE}: ', then the translation in {TARGET_LANGUAGE}.
```
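
The two templates above can be filled in programmatically. The sketch below uses plain string formatting; the helper names are hypothetical, and the message layout mirrors the audio example earlier in this card, with the audio placed before the text.

```python
ASR_TEMPLATE = (
    "Transcribe the following speech segment in {language} into {language} text.\n"
    "Follow these specific instructions for formatting the answer:\n"
    "* Only output the transcription, with no newlines.\n"
    "* When transcribing numbers, write the digits, i.e. write 1.7 and not "
    "one point seven, and write 3 instead of three."
)

AST_TEMPLATE = (
    "Transcribe the following speech segment in {source}, then translate it into {target}.\n"
    "When formatting the answer, first output the transcription in {source}, "
    "then one newline, then output the string '{target}: ', "
    "then the translation in {target}."
)


def asr_prompt(language: str) -> str:
    """Speech recognition prompt for a single language."""
    return ASR_TEMPLATE.format(language=language)


def ast_prompt(source: str, target: str) -> str:
    """Speech translation prompt from `source` to `target`."""
    return AST_TEMPLATE.format(source=source, target=target)


def audio_message(audio_url: str, prompt: str) -> dict:
    """User turn with the audio placed before the text."""
    return {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": prompt},
        ],
    }


print(ast_prompt("English", "French"))
```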
### 7. Audio and Video Length
All models support image inputs and can process videos as frames, whereas the E2B and E4B models also support audio inputs. Audio inputs are limited to 30 seconds. Video inputs are limited to 60 seconds, assuming frames are processed at one frame per second.
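These limits can be checked before any input reaches the model. The helpers below are a hypothetical sketch: splitting long audio into consecutive windows is a workaround suggested here, not something this card prescribes, and frame sampling simply takes one timestamp per second.

```python
MAX_AUDIO_SECONDS = 30
MAX_VIDEO_SECONDS = 60
FRAMES_PER_SECOND = 1


def audio_chunks(duration_s: float) -> list[tuple[float, float]]:
    """Split an audio clip into consecutive (start, end) windows of at most 30 s."""
    chunks = []
    start = 0.0
    while start < duration_s:
        end = min(start + MAX_AUDIO_SECONDS, duration_s)
        chunks.append((start, end))
        start = end
    return chunks


def frame_timestamps(duration_s: float) -> list[float]:
    """Timestamps in seconds for sampling a video at one frame per second."""
    if duration_s > MAX_VIDEO_SECONDS:
        raise ValueError(f"video longer than {MAX_VIDEO_SECONDS} s")
    frame_count = int(duration_s * FRAMES_PER_SECOND)
    return [index / FRAMES_PER_SECOND for index in range(frame_count)]


print(audio_chunks(75.0))
print(len(frame_timestamps(60.0)))
```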
## **Model Data**
Data used for model training and how the data was processed.
### **Training Dataset**
Our pre-training dataset is a large-scale, diverse collection of data spanning a wide range of domains and modalities, including web documents, code, images, and audio, with a cutoff date of January 2025. Here are the key components:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 140 languages.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code and understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning, symbolic representation, and to address mathematical queries.
* **Images**: A wide range of images enables the model to perform image analysis and visual data extraction tasks.
The combination of these diverse data sources is crucial for training a powerful multimodal model that can handle a wide variety of different tasks and data formats.
### **Data Preprocessing**
Here are the key data cleaning and filtering methods applied to the training data:
* **CSAM Filtering**: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
* **Sensitive Data Filtering**: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
* **Additional methods**: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
## **Ethics and Safety**
As open models become central to enterprise infrastructure, provenance and security are paramount. Developed by Google DeepMind, Gemma 4 undergoes the same rigorous safety evaluations as our proprietary Gemini models.
### **Evaluation Approach**
Gemma 4 models were developed in partnership with internal safety and responsible AI teams. A range of automated and human evaluations were conducted to help improve model safety. These evaluations align with [Google's AI principles](https://ai.google/principles/), as well as safety policies, which aim to prevent our generative AI models from generating harmful content, including:
* Content related to child sexual abuse material and exploitation
* Dangerous content (e.g., promoting suicide, or instructing in activities that could cause real-world harm)
* Sexually explicit content
* Hate speech (e.g., dehumanizing members of protected groups)
* Harassment (e.g., encouraging violence against people)
### **Evaluation Results**
Across all areas of safety testing, we saw major improvements in every category of content safety relative to previous Gemma models. Overall, Gemma 4 models significantly outperform Gemma 3 and 3n models on safety while keeping unjustified refusals low. All testing was conducted without safety filters to evaluate the model capabilities and behaviors. For both text-to-text and image-to-text, and across all model sizes, the model produced minimal policy violations and showed significant improvements over previous Gemma models' performance.
## **Usage and Limitations**
These models have certain limitations that users should be aware of.
### **Intended Usage**
Multimodal models (capable of processing vision, language, and/or audio) have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
* **Content Creation and Communication**
* **Text Generation**: These models can be used to generate creative text formats such as poems, scripts, code, marketing copy, and email drafts.
* **Chatbots and Conversational AI**: Power conversational interfaces for customer service, virtual assistants, or interactive applications.
* **Text Summarization**: Generate concise summaries of a text corpus, research papers, or reports.
* **Image Data Extraction**: These models can be used to extract, interpret, and summarize visual data for text communications.
* **Audio Processing and Interaction**: The smaller models (E2B and E4B) can analyze and interpret audio inputs, enabling voice-driven interactions and transcriptions.
* **Research and Education**
* **Natural Language Processing (NLP) and VLM Research**: These models can serve as a foundation for researchers to experiment with VLM and NLP techniques, develop algorithms, and contribute to the advancement of the field.
* **Language Learning Tools**: Support interactive language learning experiences, aiding in grammar correction or providing writing practice.
* **Knowledge Exploration**: Assist researchers in exploring large bodies of text by generating summaries or answering questions about specific topics.
### **Limitations**
* **Training Data**
* The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
* The scope of the training dataset determines the subject areas the model can handle effectively.
* **Context and Task Complexity**
* Models perform well on tasks that can be framed with clear prompts and instructions. Open-ended or highly complex tasks might be challenging.
* A model's performance can be influenced by the amount of context provided (longer context generally leads to better outputs, up to a certain point).
* **Language Ambiguity and Nuance**
* Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
* **Factual Accuracy**
* Models generate responses based on information they learned from their training datasets, but they are not knowledge bases. They may generate incorrect or outdated factual statements.
* **Common Sense**
* Models rely on statistical patterns in language. They might lack the ability to apply common sense reasoning in certain situations.
### **Ethical Considerations and Risks**
The development of vision-language models (VLMs) raises several ethical concerns. In creating an open model, we have carefully considered the following:
* **Bias and Fairness**
* VLMs trained on large-scale, real-world text and image data can reflect socio-cultural biases embedded in the training material. Gemma 4 models underwent careful scrutiny, input data pre-processing, and post-training evaluations as reported in this card to help mitigate the risk of these biases.
* **Misinformation and Misuse**
* VLMs can be misused to generate text that is false, misleading, or harmful.
* Guidelines are provided for responsible use with the model, see the [Responsible Generative AI Toolkit](https://ai.google.dev/responsible).
* **Transparency and Accountability**
* This model card summarizes details on the models' architecture, capabilities, limitations, and evaluation processes.
* A responsibly developed open model offers the opportunity to share innovation by making VLM technology accessible to developers and researchers across the AI ecosystem.
**Risks identified and mitigations**:
* **Generation of harmful content**: Mechanisms and guidelines for content safety are essential. Developers are encouraged to exercise caution and implement appropriate content safety safeguards based on their specific product policies and application use cases.
* **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of VLMs. Educational resources and reporting mechanisms for users to flag misuse are provided.
* **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
* **Perpetuation of biases**: Continuous monitoring (using evaluation metrics and human review) and the exploration of de-biasing techniques are encouraged during model training, fine-tuning, and other use cases.
### **Benefits**
At the time of release, this family of models provides high-performance open vision-language model implementations that, compared to similarly sized models, are designed from the ground up for responsible AI development.
@@ -0,0 +1,344 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'OBJECT' -%}
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
}
{%- elif value is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value, value['required'] | default([])) -}}
}
{%- endif -%}
{%- if value['required'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- macro format_tool_response_block(tool_name, response) -%}
{{- '<|tool_response>' -}}
{%- if response is mapping -%}
{{- 'response:' + tool_name + '{' -}}
{%- for key, value in response | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{- bos_token -}}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>\n' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Pre-scan: find last user message index for reasoning guard -#}
{%- set ns_turn = namespace(last_user_idx=-1) -%}
{%- for i in range(loop_messages | length) -%}
{%- if loop_messages[i]['role'] == 'user' -%}
{%- set ns_turn.last_user_idx = i -%}
{%- endif -%}
{%- endfor -%}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- if message['role'] != 'tool' -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
{%- set prev_nt = namespace(role=None, found=false) -%}
{%- if loop.index0 > 0 -%}
{%- for j in range(loop.index0 - 1, -1, -1) -%}
{%- if not prev_nt.found -%}
{%- if loop_messages[j]['role'] != 'tool' -%}
{%- set prev_nt.role = loop_messages[j]['role'] -%}
{%- set prev_nt.found = true -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
{%- if not continue_same_model_turn -%}
{{- '<|turn>' + role + '\n' }}
{%- endif -%}
{#- Render reasoning/reasoning_content as thinking channel -#}
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- set ns_tr_out = namespace(flag=false) -%}
{%- if message.get('tool_responses') -%}
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
{%- for tool_response in message['tool_responses'] -%}
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endfor -%}
{%- elif message.get('tool_calls') -%}
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
{%- set ns_tool_scan = namespace(stopped=false) -%}
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
{%- if ns_tool_scan.stopped -%}
{%- elif loop_messages[k]['role'] != 'tool' -%}
{%- set ns_tool_scan.stopped = true -%}
{%- else -%}
{%- set follow = loop_messages[k] -%}
{#- Resolve tool_call_id to function name -#}
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
{%- for tc in message['tool_calls'] -%}
{%- if tc.get('id') == follow.get('tool_call_id') -%}
{%- set ns_tname.name = tc['function']['name'] -%}
{%- endif -%}
{%- endfor -%}
{#- Handle content as string or content-parts array -#}
{%- set tool_body = follow.get('content') -%}
{%- if tool_body is string -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- elif tool_body is sequence and tool_body is not string -%}
{%- set ns_txt = namespace(s='') -%}
{%- for part in tool_body -%}
{%- if part.get('type') == 'text' -%}
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
{%- endif -%}
{%- endfor -%}
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
{%- else -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- endif -%}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '<|image|>' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '<|video|>' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
{{- '<|tool_response>' -}}
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- endif -%}
@@ -0,0 +1,74 @@
{
"audio_token": "<|audio|>",
"backend": "tokenizers",
"boa_token": "<|audio>",
"boi_token": "<|image>",
"bos_token": "<bos>",
"eoa_token": "<audio|>",
"eoc_token": "<channel|>",
"eoi_token": "<image|>",
"eos_token": "<eos>",
"eot_token": "<turn|>",
"escape_token": "<|\"|>",
"etc_token": "<tool_call|>",
"etd_token": "<tool|>",
"etr_token": "<tool_response|>",
"extra_special_tokens": [
"<|video|>"
],
"image_token": "<|image|>",
"mask_token": "<mask>",
"model_max_length": 1000000000000000019884624838656,
"pad_token": "<pad>",
"padding_side": "left",
"processor_class": "Gemma4Processor",
"response_schema": {
"type": "object",
"properties": {
"role": {
"const": "assistant"
},
"thinking": {
"type": "string"
},
"content": {
"type": "string"
},
"tool_calls": {
"x-regex-iterator": "<\\|tool_call>(.*?)<tool_call\\|>",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {
"const": "function"
},
"function": {
"type": "object",
"x-regex": "call\\:(?P<name>\\w+)(?P<arguments>\\{.*\\})",
"properties": {
"name": {
"type": "string"
},
"arguments": {
"type": "object",
"x-parser": "gemma4-tool-call",
"additionalProperties": {}
}
}
}
}
}
}
},
"x-regex": "(\\<\\|channel\\>thought\\n(?P<thinking>.*?)\\<channel\\|\\>)?(?P<tool_calls>\\<\\|tool_call\\>.*\\<tool_call\\|\\>)?(?P<content>(?:(?!\\<turn\\|\\>)(?!\\<\\|tool_response\\>).)+)?(?:\\<turn\\|\\>|\\<\\|tool_response\\>)?"
},
"soc_token": "<|channel>",
"sot_token": "<|turn>",
"stc_token": "<|tool_call>",
"std_token": "<|tool>",
"str_token": "<|tool_response>",
"think_token": "<|think|>",
"tokenizer_class": "GemmaTokenizer",
"unk_token": "<unk>"
}
@@ -0,0 +1,389 @@
# Gemma4_(E2B)-Multimodal.ipynb — extracted cells
# Source: https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb
# ===== CELL 0 (markdown) =====
# This notebook has vibe-test examples for the image, text, and audio capabilities of the Gemma-4 model. To get started, let's install the latest stable release of transformers.
# ===== CELL 1 (code) =====
!pip install -U transformers
# ===== CELL 2 (markdown) =====
# We can load the model with `AutoModelForMultimodalLM` to make use of all its capabilities.
# ===== CELL 3 (code) =====
import torch
from PIL import Image
from transformers import AutoModelForMultimodalLM, AutoProcessor
#model_list = ["google/gemma-4-26B-A4B-it", "google/gemma-4-E4B-it",
# "google/gemma-4-E2B-it", "google/gemma-4-31B-it"]
model_id = "google/gemma-4-E2B-it"
model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
# ===== CELL 4 (markdown) =====
# ## Code completion
# ===== CELL 5 (markdown) =====
# We give Gemma-4 a screenshot of a website and ask it to reproduce the page's HTML.
# ===== CELL 6 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
            },
            {"type": "text", "text": "Write HTML code for this page."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    enable_thinking=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=4000)
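# A minimal sketch of the text that `apply_chat_template` produces for a plain-text
# conversation, following the turn structure of the chat template shipped with the
# model (`<|turn>role\n ... <turn|>\n`). `render_turns` and `strip_thinking` are
# hypothetical helpers written for illustration; they are not the transformers API.
# Assumptions: string-only content, no tools, no image/audio/video parts, and no
# merging of consecutive assistant messages.
BOS = "<bos>"  # bos_token from tokenizer_config.json


def strip_thinking(text):
    """Drop `<|channel>...<channel|>` spans, as the template does for past model turns."""
    result = ""
    for part in text.split("<channel|>"):
        # Keep only what precedes an opening channel marker, if there is one.
        result += part.split("<|channel>")[0] if "<|channel>" in part else part
    return result.strip()


def render_turns(messages, enable_thinking=False, add_generation_prompt=True):
    out = [BOS]
    has_system = messages[0]["role"] in ("system", "developer")
    if enable_thinking or has_system:
        out.append("<|turn>system\n")
        if enable_thinking:
            out.append("<|think|>\n")  # thinking switch sits at the top of the system turn
        if has_system:
            out.append(messages[0]["content"].strip())
            messages = messages[1:]
        out.append("<turn|>\n")
    for message in messages:
        role = "model" if message["role"] == "assistant" else message["role"]
        content = message["content"]
        body = strip_thinking(content) if role == "model" else content.strip()
        out.append(f"<|turn>{role}\n{body}<turn|>\n")
    if add_generation_prompt:
        out.append("<|turn>model\n")  # leaves the prompt open for the model's reply
    return "".join(out)


print(render_turns([{"role": "user", "content": "Hi"}], enable_thinking=True))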
# ===== CELL 7 (code) =====
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
print(result["content"])
# ===== CELL 8 (markdown) =====
# ## Video Inference
# ===== CELL 9 (markdown) =====
# We test Gemma-4 on video understanding. If you run this example with one of the larger models that don't take audio input, disable `load_audio_from_video`.
# ===== CELL 10 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
            {"type": "text", "text": "What is happening in the video? What is the song about?"},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    load_audio_from_video=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=200)
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
# ===== CELL 11 (code) =====
print(result["content"])
# ===== CELL 12 (markdown) =====
# ## Multimodal Function Calling
# ===== CELL 13 (code) =====
import re
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Gets the current weather for a specific location.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "The city name"},
            },
            "required": ["city"],
        },
    },
}
tools = [WEATHER_TOOL]
messages = [
    {"role": "user", "content": [
        {"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
        {"type": "text", "text": "What is the city in this image? Check the weather there right now."},
    ]},
]
inputs = processor.apply_chat_template(
    messages,
    tools=[WEATHER_TOOL],
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    enable_thinking=True,
).to(model.device)
# ===== CELL 14 (code) =====
output = model.generate(**inputs, max_new_tokens=1000)
# ===== CELL 15 (code) =====
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
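# A rough sketch of how the `response_schema` regexes in tokenizer_config.json carve a
# raw generation into thinking, tool calls, and content. This is not the implementation
# of `processor.parse_response`; `split_response` is a hypothetical helper, the use of
# re.DOTALL is an assumption, and the sample string is hand-written and assumes the
# channel and tool-call markers are still present in the decoded text. Arguments stay a
# raw string because the "gemma4-tool-call" argument parser is not reproduced here.
import re

RESPONSE_RE = re.compile(
    r"(<\|channel>thought\n(?P<thinking>.*?)<channel\|>)?"
    r"(?P<tool_calls><\|tool_call>.*<tool_call\|>)?"
    r"(?P<content>(?:(?!<turn\|>)(?!<\|tool_response>).)+)?"
    r"(?:<turn\|>|<\|tool_response>)?",
    re.DOTALL,
)
TOOL_CALL_RE = re.compile(r"<\|tool_call>(.*?)<tool_call\|>", re.DOTALL)
FUNCTION_RE = re.compile(r"call:(?P<name>\w+)(?P<arguments>\{.*\})", re.DOTALL)


def split_response(text):
    # Every group in RESPONSE_RE is optional, so match() always succeeds.
    match = RESPONSE_RE.match(text)
    calls = []
    for raw_call in TOOL_CALL_RE.findall(match.group("tool_calls") or ""):
        function = FUNCTION_RE.match(raw_call)
        if function:
            calls.append({"name": function.group("name"),
                          "arguments": function.group("arguments")})
    return {
        "thinking": match.group("thinking"),
        "tool_calls": calls,
        "content": match.group("content"),
    }


sample = (
    "<|channel>thought\nThe photo looks like Bangkok.\n<channel|>"
    '<|tool_call>call:get_weather{city:<|"|>Bangkok<|"|>}<tool_call|><|tool_response>'
)
print(split_response(sample))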
# ===== CELL 16 (code) =====
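print(result.get("tool_calls"))  # parsed function call, per response_schema in tokenizer_config.json
print(result.get("content"))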
# ===== CELL 17 (markdown) =====
# # Any-to-any inference
# ===== CELL 18 (markdown) =====
# We can also run the model with `any-to-any` pipeline.
# ===== CELL 19 (code) =====
from transformers import pipeline
pipe = pipeline("any-to-any", model="google/gemma-4-E2B-it")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
            },
            {"type": "text", "text": "What is happening in this video?"},
        ],
    }
]
# ===== CELL 20 (code) =====
pipe(messages)#, load_audio_from_video=True)
# ===== CELL 21 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
            },
            {"type": "text", "text": "What is happening in this video?"},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt"
)
inputs = inputs.to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=128)
# Drop the prompt tokens from each row so only the newly generated tokens are decoded.
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
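# A small sketch of the prompt-trimming step used above, on plain Python lists instead
# of tensors. Each generated row starts with its prompt, so slicing off `len(in_ids)`
# tokens leaves only the new tokens. `trim_prompt` is a hypothetical helper and the
# token ids are made up. With `padding_side` set to "left" (as in tokenizer_config.json),
# every prompt row in a batch has the same padded length.
def trim_prompt(input_ids, generated_ids):
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


prompt_rows = [[0, 0, 5, 6], [1, 2, 3, 4]]  # first row is left-padded with 0
output_rows = [[0, 0, 5, 6, 9, 8], [1, 2, 3, 4, 7, 7]]
print(trim_prompt(prompt_rows, output_rows))  # [[9, 8], [7, 7]]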
# ===== CELL 22 (markdown) =====
# # Object detection and pointing
# ===== CELL 23 (code) =====
import re
import torch
from transformers.image_utils import load_image
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import json
# ===== CELL 24 (code) =====
image_url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png"
image = load_image(image_url)
# ===== CELL 25 (code) =====
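def resize_to_48_multiple(image):
    # Despite the name, this crops from the top-left corner rather than resizing:
    # each side is cut down to the largest multiple of 48 that fits.
    w, h = image.size
    new_w = (w // 48) * 48
    new_h = (h // 48) * 48
    return image.crop((0, 0, new_w, new_h))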
# ===== CELL 26 (code) =====
def inputs_for_object_detection(image, what_object):
    messages = [
        {
            "role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"What's the bounding box for the {what_object} in the image?"}
            ]
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
    )
    return inputs.to(model.device)
# ===== CELL 27 (code) =====
def extract_json(text: str):
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    # Try direct parse first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fallback: extract first JSON object or array
    match = re.search(r'(\{.*\}|\[.*\])', text, re.DOTALL)
    if match:
        candidate = match.group(1)
        return json.loads(candidate)
    raise ValueError("No valid JSON found")
# ===== CELL 28 (code) =====
def detect_object(image_url, what_object):
    image = load_image(image_url)
    image = resize_to_48_multiple(image)
    inputs = inputs_for_object_detection(image, what_object)
    input_len = inputs["input_ids"].shape[-1]
    generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
    generated = processor.decode(generated_outputs[0, input_len:])
    parsed_json = extract_json(generated)[0]
    return parsed_json
# ===== CELL 29 (code) =====
def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):
    # Note: `box` is unpacked as (ymin, xmin, ymax, xmax) on a `resize_shape` grid,
    # not in the usual Pascal VOC (xmin, ymin, xmax, ymax) order.
    dpi = 72
    width, height = image.size
    fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})
    ax.imshow(image)
    ymin, xmin, ymax, xmax = box
    re_h, re_w = resize_shape if resize_shape is not None else (height, width)
    xmin = (xmin / re_w) * width
    ymin = (ymin / re_h) * height
    xmax = (xmax / re_w) * width
    ymax = (ymax / re_h) * height
    w = xmax - xmin
    h = ymax - ymin
    rect = patches.Rectangle(
        (xmin, ymin),
        w,
        h,
        linewidth=10,
        edgecolor="green",
        facecolor="none"
    )
    ax.add_patch(rect)
    if label is not None:
        ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")
    plt.savefig(f"boxes_{i}.png")
    plt.close(fig)
    display(fig)  # `display` is the IPython built-in available in notebooks
# ===== CELL 30 (code) =====
def display_detected_object(image_url, what_object):
    image = load_image(image_url)
    image = resize_to_48_multiple(image)
    detection = detect_object(image_url, what_object)
    box = detection["box_2d"]
    label = detection.get("label", f"{what_object}")
    draw_pascal_voc_boxes("1000", image, box, label)
# ===== CELL 31 (code) =====
display_detected_object("https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png", "bike")
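# A minimal sketch of the coordinate conversion inside `draw_pascal_voc_boxes`, without
# matplotlib. It assumes, as that function does, that `box_2d` is ordered
# (ymin, xmin, ymax, xmax) on a 0-1000 grid. `box_to_pixels` is a hypothetical helper
# and the sample box and image size are made up.
def box_to_pixels(box, width, height, grid=1000):
    ymin, xmin, ymax, xmax = box
    # Scale each coordinate from the grid to the image size and return it x-first,
    # as (xmin, ymin, xmax, ymax) in pixels.
    return (xmin * width / grid, ymin * height / grid,
            xmax * width / grid, ymax * height / grid)


print(box_to_pixels([250, 100, 750, 900], width=960, height=672))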
# ===== CELL 32 (markdown) =====
# ## Captioning
# ===== CELL 33 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
            {"type": "text", "text": "Write single detailed caption for this image."},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=512)
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
print(result["content"])
# ===== CELL 34 (markdown) =====
# ## Audio Understanding
# ===== CELL 35 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
            {"type": "text", "text": "Can you describe this audio in detail?"},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)
print(processor.decode(output[0], skip_special_tokens=True))
@@ -0,0 +1,595 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"This notebook has vibe-test examples for the image, text, and audio capabilities of the Gemma-4 model. To get started, let's install the latest stable release of transformers."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"!pip install -U transformers"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can load the model with `AutoModelForMultimodalLM` to make use of all its capabilities."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from PIL import Image\n",
"\n",
"from transformers import AutoModelForMultimodalLM, AutoProcessor\n",
"#model_list = [\"google/gemma-4-26B-A4B-it\", \"google/gemma-4-E4B-it\",\n",
"# \"google/gemma-4-E2B-it\", \"google/gemma-4-31B-it\"]\n",
"model_id = \"google/gemma-4-E2B-it\"\n",
"model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map=\"auto\")\n",
"processor = AutoProcessor.from_pretrained(model_id)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Code completion"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We give Gemma-4 a screenshot of a website and ask it to reproduce the page's HTML."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\n",
" \"type\": \"image\",\n",
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png\",\n",
" },\n",
" {\"type\": \"text\", \"text\": \"Write HTML code for this page.\"},\n",
" ],\n",
" }\n",
"]\n",
"\n",
"inputs = processor.apply_chat_template(\n",
" messages,\n",
" tokenize=True,\n",
" return_dict=True,\n",
" return_tensors=\"pt\",\n",
" add_generation_prompt=True,\n",
" enable_thinking=True,\n",
").to(model.device)\n",
"\n",
"output = model.generate(**inputs, max_new_tokens=4000)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"input_len = inputs.input_ids.shape[-1]\n",
"generated_text_ids = output[0][input_len:]\n",
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
"result = processor.parse_response(generated_text)\n",
"\n",
"print(result[\"content\"])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Video Inference"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We test Gemma-4 on video understanding. If you run this example with one of the larger models that don't take audio input, disable `load_audio_from_video`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"video\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4\"},\n",
" {\"type\": \"text\", \"text\": \"What is happening in the video? What is the song about?\"},\n",
" ],\n",
" },\n",
"]\n",
"inputs = processor.apply_chat_template(\n",
" messages,\n",
" tokenize=True,\n",
" return_dict=True,\n",
" return_tensors=\"pt\",\n",
" add_generation_prompt=True,\n",
" load_audio_from_video=True,\n",
").to(model.device)\n",
"output = model.generate(**inputs, max_new_tokens=200)\n",
"input_len = inputs.input_ids.shape[-1]\n",
"generated_text_ids = output[0][input_len:]\n",
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
"result = processor.parse_response(generated_text)\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(result[\"content\"])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Multimodal Function Calling"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import re\n",
"\n",
"WEATHER_TOOL = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Gets the current weather for a specific location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\"type\": \"string\", \"description\": \"The city name\"},\n",
" },\n",
" \"required\": [\"city\"],\n",
" },\n",
" },\n",
"}\n",
"tools = [WEATHER_TOOL]\n",
"\n",
"messages = [\n",
" {\"role\": \"user\", \"content\": [\n",
" {\"type\": \"image\", \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg\"},\n",
" {\"type\": \"text\", \"text\": \"What is the city in this image? Check the weather there right now.\"},\n",
" ]},\n",
"]\n",
"\n",
"inputs = processor.apply_chat_template(\n",
" messages,\n",
" tools=[WEATHER_TOOL],\n",
" tokenize=True,\n",
" return_dict=True,\n",
" return_tensors=\"pt\",\n",
" add_generation_prompt=True,\n",
" enable_thinking=True,\n",
").to(model.device)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"output = model.generate(**inputs, max_new_tokens=1000)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"input_len = inputs.input_ids.shape[-1]\n",
"generated_text_ids = output[0][input_len:]\n",
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
"result = processor.parse_response(generated_text)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
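{
"cell_type": "code",
"source": [
"print(result.get(\"tool_calls\"))  # parsed function call, per response_schema in tokenizer_config.json\n",
"print(result.get(\"content\"))"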
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Any-to-any inference"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We can also run the model with `any-to-any` pipeline."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from transformers import pipeline\n",
"\n",
"pipe = pipeline(\"any-to-any\", model=\"google/gemma-4-E2B-it\")\n",
"\n",
"messages = [\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": [\n",
"            {\n",
"                \"type\": \"video\",\n",
"                \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
" },\n",
" {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
" ],\n",
" }\n",
"]\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pipe(messages)#, load_audio_from_video=True)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"messages = [\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": [\n",
"            {\n",
"                \"type\": \"video\",\n",
"                \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
" },\n",
" {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
" ],\n",
" }\n",
"]\n",
"\n",
"inputs = processor.apply_chat_template(\n",
" messages,\n",
" tokenize=True,\n",
" add_generation_prompt=True,\n",
" return_dict=True,\n",
" return_tensors=\"pt\"\n",
")\n",
"inputs = inputs.to(model.device)\n",
"\n",
"generated_ids = model.generate(**inputs, max_new_tokens=128)\n",
"generated_ids_trimmed = [\n",
" out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
"]\n",
"output_text = processor.batch_decode(\n",
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
")\n",
"print(output_text)\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Object detection and pointing"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import re\n",
"import torch\n",
"from transformers.image_utils import load_image\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"import json"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"image_url = \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\"\n",
"image = load_image(image_url)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def resize_to_48_multiple(image):\n",
" w, h = image.size\n",
" new_w = (w // 48) * 48\n",
" new_h = (h // 48) * 48\n",
" return image.crop((0, 0, new_w, new_h))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def inputs_for_object_detection(image, what_object):\n",
" messages = [\n",
" {\n",
" \"role\": \"user\", \"content\": [\n",
" {\"type\": \"image\", \"image\": image},\n",
" {\"type\": \"text\", \"text\": f\"What's the bounding box for the {what_object} in the image?\"}\n",
" ]\n",
" }\n",
" ]\n",
"\n",
" inputs = processor.apply_chat_template(\n",
" messages,\n",
" tokenize=True,\n",
" add_generation_prompt=True,\n",
" return_dict=True,\n",
" return_tensors=\"pt\",\n",
" enable_thinking=False,\n",
" )\n",
"\n",
" return inputs.to(model.device)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def extract_json(text: str):\n",
" text = text.strip()\n",
"\n",
" text = re.sub(r\"^```(?:json)?\\s*\", \"\", text)\n",
" text = re.sub(r\"\\s*```$\", \"\", text)\n",
"\n",
" # Try direct parse first\n",
" try:\n",
" return json.loads(text)\n",
" except json.JSONDecodeError:\n",
" pass\n",
"\n",
" # Fallback: extract first JSON object or array\n",
" match = re.search(r'(\\{.*\\}|\\[.*\\])', text, re.DOTALL)\n",
" if match:\n",
" candidate = match.group(1)\n",
" return json.loads(candidate)\n",
"\n",
" raise ValueError(\"No valid JSON found\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def detect_object(image_url, what_object):\n",
"    image = load_image(image_url)\n",
"    image = resize_to_48_multiple(image)\n",
"    inputs = inputs_for_object_detection(image, what_object)\n",
"    input_len = inputs[\"input_ids\"].shape[-1]\n",
"    generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)\n",
"    generated = processor.decode(generated_outputs[0, input_len:])\n",
"    parsed_json = extract_json(generated)[0]\n",
"    return parsed_json"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):\n",
"    dpi = 72\n",
"    width, height = image.size\n",
"    fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})\n",
"\n",
"    ax.imshow(image)\n",
"\n",
"    ymin, xmin, ymax, xmax = box\n",
"    re_h, re_w = resize_shape if resize_shape is not None else (height, width)\n",
"    xmin = (xmin / re_w) * width\n",
"    ymin = (ymin / re_h) * height\n",
"    xmax = (xmax / re_w) * width\n",
"    ymax = (ymax / re_h) * height\n",
"\n",
"    w = xmax - xmin\n",
"    h = ymax - ymin\n",
"\n",
"    rect = patches.Rectangle(\n",
"        (xmin, ymin),\n",
"        w,\n",
"        h,\n",
"        linewidth=10,\n",
"        edgecolor=\"green\",\n",
"        facecolor=\"none\"\n",
"    )\n",
"    ax.add_patch(rect)\n",
"\n",
"    if label is not None:\n",
"        ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor=\"yellow\", alpha=0.5))\n",
"\n",
"    plt.axis(\"off\")\n",
"    plt.savefig(f\"boxes_{i}.png\")\n",
"    plt.close(fig)\n",
"    display(fig)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def display_detected_object(image_url, what_object):\n",
"    image = load_image(image_url)\n",
"    image = resize_to_48_multiple(image)\n",
"    detection = detect_object(image_url, what_object)\n",
"    box = detection[\"box_2d\"]\n",
"    label = detection.get(\"label\", f\"{what_object}\")\n",
"    draw_pascal_voc_boxes(\"1000\", image, box, label)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"display_detected_object(\"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\", \"bike\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Captioning"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"messages = [\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": [\n",
"            {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png\"},\n",
"            {\"type\": \"text\", \"text\": \"Write a single detailed caption for this image.\"},\n",
"        ],\n",
"    },\n",
"]\n",
"\n",
"inputs = processor.apply_chat_template(\n",
"    messages,\n",
"    tokenize=True,\n",
"    return_dict=True,\n",
"    return_tensors=\"pt\",\n",
"    add_generation_prompt=True,\n",
").to(model.device)\n",
"\n",
"output = model.generate(**inputs, max_new_tokens=512)\n",
"input_len = inputs.input_ids.shape[-1]\n",
"generated_text_ids = output[0][input_len:]\n",
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
"result = processor.parse_response(generated_text)\n",
"print(result[\"content\"])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Audio Understanding"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"messages = [\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": [\n",
"            {\"type\": \"audio\", \"url\": \"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3\"},\n",
"            {\"type\": \"text\", \"text\": \"Can you describe this audio in detail?\"},\n",
"        ],\n",
"    },\n",
"]\n",
"\n",
"inputs = processor.apply_chat_template(\n",
"    messages,\n",
"    tokenize=True,\n",
"    return_dict=True,\n",
"    return_tensors=\"pt\",\n",
"    add_generation_prompt=True,\n",
").to(model.device)\n",
"\n",
"output = model.generate(\n",
"    **inputs,\n",
"    max_new_tokens=1000,\n",
"    do_sample=False,\n",
")\n",
"\n",
"print(processor.decode(output[0], skip_special_tokens=True))\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
}
]
}
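The detection cells in the notebook above combine two steps that can be checked without loading a model: pulling JSON out of a fenced reply, and mapping a `box_2d` given on a 0-1000 grid back to pixel coordinates. The following is a minimal sketch, assuming the reply looks like the hypothetical `reply` string below and that `box_2d` is ordered `[ymin, xmin, ymax, xmax]`, which is how the notebook's drawing code unpacks it. `scale_box` is an illustrative name, not part of the notebook.

```python
import json
import re


def extract_json(text: str):
    # Parse JSON from a model reply, tolerating a Markdown code fence.
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fallback: first JSON object or array embedded in surrounding prose.
    match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
    if match:
        return json.loads(match.group(1))
    raise ValueError("No valid JSON found")


def scale_box(box, width, height, grid=1000):
    # Map [ymin, xmin, ymax, xmax] on a grid x grid canvas
    # to pixel coordinates (xmin, ymin, xmax, ymax).
    ymin, xmin, ymax, xmax = box
    return (
        xmin * width / grid,
        ymin * height / grid,
        xmax * width / grid,
        ymax * height / grid,
    )


# Hypothetical reply text; real model output may differ.
reply = '```json\n[{"box_2d": [100, 250, 500, 750], "label": "bike"}]\n```'
detection = extract_json(reply)[0]
print(detection["label"], scale_box(detection["box_2d"], width=960, height=480))
```

Returning pixel coordinates in `(xmin, ymin, xmax, ymax)` order keeps the result directly usable as the anchor and extent of a matplotlib rectangle.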
@@ -0,0 +1,302 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl",
# "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env",
# ]
# ///
"""
GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models).
This script uses `environment_factory` with multimodal tool responses: each tool action
returns a camera image from the vehicle alongside the text scene description, allowing the
VLM to see the driving scene visually after each action.
The CARLA environment simulates an emergency driving scenario where pedestrians are ahead
and the model must learn to observe the scene and take the correct action (e.g., swerve
to an empty lane) to minimize casualties.
Setup:
```sh
pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env"
```
Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection):
```sh
python examples/scripts/openenv/carla_vlm.py \
--env-urls https://server1.hf.space https://server2.hf.space
```
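The one-connection-per-Space constraint can be sketched as follows. This is a minimal illustration, assuming (as this script does) that every environment instance takes the next URL from a shared iterator and that one environment is created per generation. `UrlBoundEnv` and `build_envs` are hypothetical names, not part of TRL or OpenEnv.

```python
class UrlBoundEnv:
    # Hypothetical stand-in for an environment that owns one server URL.
    def __init__(self, url_iter):
        # Each instance claims the next unused URL from the shared iterator.
        self.url = next(url_iter)


def build_envs(env_urls, num_generations):
    # One environment per generation; fail early if the URLs would run out.
    if num_generations > len(env_urls):
        raise ValueError("need one server URL per concurrent generation")
    url_iter = iter(env_urls)
    return [UrlBoundEnv(url_iter) for _ in range(num_generations)]


urls = ["https://server1.hf.space", "https://server2.hf.space"]
envs = build_envs(urls, num_generations=len(urls))
print([env.url for env in envs])
```

Tying the number of generations to the number of URLs means no two rollouts ever share a server.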
"""
import argparse
import base64
from io import BytesIO
from carla_env import CarlaAction, CarlaEnv
from datasets import Dataset
from PIL import Image
from trl import GRPOConfig, GRPOTrainer
def parse_args():
    parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.")
    parser.add_argument("--model", type=str, default="google/gemma-4-E2B-it")
    parser.add_argument(
        "--env-urls",
        type=str,
        nargs="+",
        required=True,
        help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).",
    )
    parser.add_argument("--dataset-size", type=int, default=1000)
    parser.add_argument("--max-completion-length", type=int, default=3072)
    parser.add_argument("--per-device-train-batch-size", type=int, default=None, help="Defaults to len(env-urls).")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.")
    parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.")
    parser.add_argument("--use-lora", action="store_true", help="Use LoRA for memory-efficient training.")
    parser.add_argument("--lora-r", type=int, default=128, help="LoRA rank.")
    parser.add_argument("--lora-alpha", type=int, default=256, help="LoRA alpha.")
    parser.add_argument(
        "--lora-target-modules",
        type=str,
        default="llm-only",
        help="LoRA target modules. Use 'llm-only' to skip vision encoder, 'all-linear' for all.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=5e-6, help="Learning rate. Default 5e-6 (good for LoRA r=128)."
    )
    parser.add_argument("--hub-model-id", type=str, default=None)
    parser.add_argument("--hub-private-repo", action="store_true", help="Make the Hub repo private.")
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.")
    return parser.parse_args()
SIM_TICKS = 10
def reward_func(completions, environments, **kwargs):
    rewards = []
    for i, (comp, env) in enumerate(zip(completions, environments, strict=False)):
        # Advance to episode end to capture rubric_reward
        try:
            final = env._advance_until_done()
            if final and final.observation.rubric_reward:
                env.reward = final.observation.rubric_reward
        except Exception as e:
            print(f"[WARN] _advance_until_done failed for gen={i}: {e}")
        r = env.reward
        tools = [
            msg["tool_calls"][0]["function"]["name"] for msg in comp if isinstance(msg, dict) and msg.get("tool_calls")
        ]
        rewards.append(r)
        print(f"[DEBUG reward] gen={i} tools={tools} env_reward={env.reward} total={r}")
    return rewards
def main():
    args = parse_args()
    env_url_iter = iter(args.env_urls)
    image_size = args.image_size

    prompt = """\
You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \
decide what to do immediately.
You will see a camera image from the vehicle after each action. Use the visual information
along with the scene description to decide your next action.
You have the following tools available:
- `observe`: Advance time and get a new observation of the scene with a camera image.
- `emergency_stop`: Apply maximum braking to stop the vehicle.
- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right".
Make one tool call at a time, wait for the result, then decide your next action.
Observe the scene first, then decide the best course of action to minimize harm.
Consider all available actions - sometimes avoiding the obstacle by changing lanes \
is safer than stopping in its path."""

    dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]})

    class CarlaVLMEnv:
        def __init__(self):
            self.url = next(env_url_iter)
            self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
            self.reward = 0.0

        @staticmethod
        def _describe(obs) -> str:
            parts = []
            parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.")
            if obs.nearby_actors:
                for actor in obs.nearby_actors:
                    parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m")
            else:
                parts.append("No nearby actors detected.")
            if obs.collision_detected:
                parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!")
            return "\n".join(parts)

        @staticmethod
        def _decode_image(camera_image_b64, target_size):
            """Decode base64 JPEG image and optionally resize."""
            img_bytes = base64.b64decode(camera_image_b64)
            img = Image.open(BytesIO(img_bytes))
            if target_size > 0:
                img.thumbnail((target_size, target_size), Image.LANCZOS)
            return img

        def _format_multimodal(self, obs) -> list:
            """Format observation as multimodal content blocks (camera image + text)."""
            content = []
            if obs.camera_image is not None:
                img = self._decode_image(obs.camera_image, image_size)
                content.append({"type": "image", "image": img})
            content.append({"type": "text", "text": self._describe(obs)})
            return content

        def _advance(self, ticks: int = SIM_TICKS):
            result = None
            for _ in range(ticks):
                result = self.client.step(CarlaAction(action_type="observe"))
                if result.done:
                    break
            return result

        def _advance_until_done(self, max_ticks: int = 50):
            """Advance the simulation until the episode ends."""
            result = None
            for _ in range(max_ticks):
                result = self.client.step(CarlaAction(action_type="observe"))
                if result.done:
                    break
            return result

        def _advance_and_capture(self, ticks: int = SIM_TICKS):
            """Advance the simulation, then capture an image of the current state."""
            result = self._advance(ticks)
            capture_result = self.client.step(CarlaAction(action_type="capture_image"))
            result.observation.camera_image = capture_result.observation.camera_image
            return result

        def reset(self, **kwargs) -> str | None:
            for attempt in range(3):
                try:
                    result = self.client.reset(scenario_name="trolley_micro_escape_exists")
                    self.reward = 0.0
                    return self._describe(result.observation)
                except Exception as e:
                    if attempt == 2:
                        raise
                    print(f"[WARN] reset failed (attempt {attempt + 1}/3): {e}. Reconnecting...")
                    self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)

        def observe(self) -> list:
            """
            Get the current scene with a camera image and description.

            Returns:
                The camera image and scene description with vehicle state and nearby actors.
            """
            result = self._advance_and_capture()
            self.reward = result.observation.rubric_reward or 0.0
            return self._format_multimodal(result.observation)

        def emergency_stop(self) -> list:
            """
            Apply maximum braking to stop the vehicle.

            Returns:
                The camera image and scene description after braking.
            """
            self.client.step(CarlaAction(action_type="emergency_stop"))
            result = self._advance_and_capture()
            self.reward = result.observation.rubric_reward or 0.0
            print(f"[DEBUG env] emergency_stop: done={result.done}, reward={self.reward}")
            return self._format_multimodal(result.observation)

        def lane_change(self, direction: str) -> list:
            """
            Change lane to avoid obstacles.

            Args:
                direction: Direction to change lane, either "left" or "right".

            Returns:
                The camera image and scene description after changing lane.
            """
            self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction))
            result = self._advance_and_capture()
            self.reward = result.observation.rubric_reward or 0.0
            print(f"[DEBUG env] lane_change({direction}): done={result.done}, reward={self.reward}")
            return self._format_multimodal(result.observation)

    peft_config = None
    if args.use_lora:
        from peft import LoraConfig

        if args.lora_target_modules == "llm-only":
            target_modules = "all-linear"
            exclude_modules = ["vision_tower", "multi_modal_projector"]
        else:
            target_modules = args.lora_target_modules
            exclude_modules = None
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=target_modules,
            exclude_modules=exclude_modules,
            task_type="CAUSAL_LM",
        )

    trainer = GRPOTrainer(
        model=args.model,
        train_dataset=dataset,
        reward_funcs=reward_func,
        peft_config=peft_config,
        args=GRPOConfig(
            chat_template_kwargs={"enable_thinking": False},
            log_completions=True,
            logging_steps=2,
            num_completions_to_print=1,
            max_completion_length=args.max_completion_length,
            per_device_train_batch_size=args.per_device_train_batch_size or len(args.env_urls),
            steps_per_generation=1,
            num_generations=len(args.env_urls),
            max_tool_calling_iterations=10,
            learning_rate=args.learning_rate,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            max_steps=args.max_steps,
            push_to_hub=args.hub_model_id is not None,
            hub_model_id=args.hub_model_id,
            hub_private_repo=args.hub_private_repo,
            run_name=args.run_name,
            report_to=args.report_to,
            trackio_space_id=args.trackio_space_id,
        ),
        environment_factory=CarlaVLMEnv,
    )
    trainer.train()


if __name__ == "__main__":
    main()
@@ -0,0 +1,184 @@
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import random
from functools import partial
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
def collate_fn(examples, processor):
    messages = list()
    for sample in examples:
        audio = sample["audio"]["array"]
        label = str(sample["text"])
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant that transcribes speech accurately.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio},
                    {"type": "text", "text": "Please transcribe this audio."},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
        messages.append(message)
    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Mask the tokens that we do not want to include in the loss computation
    # -100 is ignored during categorical cross entropy loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.audio_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100
    batch["labels"] = labels
    return batch
def freeze_layers(model):
    for name, param in model.named_parameters():
        if "attn" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
def run_inference(val_dataset, processor, model, fname):
    # Transcribe one random validation sample (called before and during training).
    # `fname` is unused here; it is kept for parity with the image fine-tuning script.
    val_sample = random.choice(val_dataset)
    audio = val_sample["audio"]["array"]
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech accurately.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio},
                {"type": "text", "text": "Please transcribe this audio."},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
        generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(f"Audio transcription: {decoded}")
    print(f"Label: {val_sample['text']}")
def main():
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # Load and split the dataset.
    ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
    split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_ds["train"].select(range(10000))
    val_dataset = split_ds["test"].select(range(100))

    # create data loader
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )

    # load the model and optimizer
    model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
        "cuda", dtype=torch.bfloat16
    )
    run_inference(val_dataset, processor, model, "pred_before.png")
    model = freeze_layers(model)
    params_to_train = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)

    # Start Training
    accumulation_steps = 8
    for idx, batch in tqdm(enumerate(train_dataloader)):
        outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
        loss = outputs.loss / accumulation_steps
        if idx % 100 == 0:
            val_loss = 0.0
            with torch.no_grad():
                count = 0
                for val_batch in tqdm(val_dataloader, desc="Validation"):
                    val_loss = (
                        val_loss
                        + model(**val_batch.to(model.device, dtype=torch.bfloat16)).loss
                    )
                    count = count + 1
                val_loss = val_loss / count
            print(
                f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")
        loss.backward()
        # Step only after a full window of micro-batches has been accumulated.
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()


if __name__ == "__main__":
    main()
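
The training loop in this script divides each loss by `accumulation_steps`, so one optimizer step is meant to cover that many micro-batches. A small framework-free sketch shows which loop indices trigger a step under the two usual modulo conditions. The helper `step_indices` is hypothetical and only models the index arithmetic, not the optimizer.

```python
def step_indices(num_batches, accumulation_steps, offset):
    # Return the loop indices at which the optimizer would step.
    return [
        idx
        for idx in range(num_batches)
        if (idx + offset) % accumulation_steps == 0
    ]


# offset=0 models `idx % N == 0`: the first step fires after one micro-batch.
print(step_indices(24, 8, offset=0))  # [0, 8, 16]

# offset=1 models `(idx + 1) % N == 0`: every step covers N micro-batches.
print(step_indices(24, 8, offset=1))  # [7, 15, 23]
```

With the zero-offset form, the first update uses a gradient built from one micro-batch that was already scaled by `1 / accumulation_steps`.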
@@ -0,0 +1,352 @@
"""
Train Gemma-3n on various vision-language datasets including intersection-dataset.
For Gemma-3n with intersection dataset:
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
sft_vlm_gemma3n.py \
--dataset_name ariG23498/intersection-dataset \
--model_name_or_path google/gemma-3n-E2B-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3n-E2B-it-trl-sft-intersection \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
Train Gemma-3 (google/gemma-3-4b-it) on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
sft_vlm_gemma3n.py \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
Train Gemma-3 (google/gemma-3-4b-it) on the FanqingM/MMIU-Benchmark dataset (multi-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
sft_vlm_gemma3n.py \
--dataset_name FanqingM/MMIU-Benchmark \
--dataset_train_split test \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
"""
import io
import os
import zipfile
import torch
from datasets import DatasetDict, load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image
from transformers import (AutoModelForImageTextToText, AutoProcessor,
Gemma3nForConditionalGeneration)
from trl import (ModelConfig, ScriptArguments, SFTConfig, SFTTrainer,
TrlParser, get_kbit_device_map, get_quantization_config)
def my_get_peft_config(model_args: ModelConfig):
    """A version of get_peft_config that handles comma-separated target modules"""
    if model_args.use_peft is False:
        return None
    # Import here to avoid issues if PEFT is not available
    try:
        from peft import LoraConfig
    except ImportError:
        raise ValueError(
            "The PEFT library is required when `use_peft` is set. "
            "Install it with `pip install -U peft`."
        )
    # Fix the target_modules to be a list if it's a comma-separated string
    target_modules = model_args.lora_target_modules
    if isinstance(target_modules, str) and target_modules != "all-linear":
        # Convert comma-separated string to list
        target_modules = [module.strip() for module in target_modules.split(",")]
    peft_config = LoraConfig(
        task_type=model_args.lora_task_type,
        r=model_args.lora_r,
        target_modules=target_modules,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        use_rslora=model_args.use_rslora,
        use_dora=model_args.use_dora,
        modules_to_save=model_args.lora_modules_to_save,
    )
    return peft_config
# For intersection dataset processing
def format_intersection_data(samples: dict) -> dict[str, list]:
    """Format intersection dataset to match expected message format"""
    formatted_samples = {"messages": []}
    for idx in range(len(samples["image"])):
        image = samples["image"][idx].convert("RGB")
        label = str(samples["label"][idx])
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant with great geometry skills.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {
                        "type": "text",
                        "text": "How many intersection points are there in the image?",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
        formatted_samples["messages"].append(message)
    return formatted_samples
# For multi-image example
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    # Handle dictionary with bytes
                    if isinstance(image, dict) and "bytes" in image:
                        pil_image = Image.open(io.BytesIO(image["bytes"]))
                        image_inputs.append(pil_image.convert("RGB"))
                    # Handle PIL Image objects
                    elif hasattr(image, "convert"):
                        image_inputs.append(image.convert("RGB"))
    return image_inputs
def format_data(samples: dict) -> dict[str, list]:
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        for img_path in samples["input_image_path"][cont]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue
        formatted_samples["messages"].append(
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": samples["context"][cont]}],
                },
                {
                    "role": "user",
                    "content": images
                    + [{"type": "text", "text": samples["question"][cont]}],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": samples["output"][cont]}],
                },
            ]
        )
    return formatted_samples
# For multi-image example
def prepare_dataset(
    dataset: DatasetDict, dataset_name: str, dataset_train_split: str
) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]
    for zip_filename in zip_files:
        zip_path = hf_hub_download(
            repo_id=dataset_name, filename=zip_filename, repo_type="dataset"
        )
        extract_folder = zip_filename.replace(".zip", "")
        os.makedirs(extract_folder, exist_ok=True)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_folder)
    dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
    return dataset
def main():
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    processor.tokenizer.padding_side = "right"

    # Use appropriate model class based on model name
    if "gemma-3n" in model_args.model_name_or_path.lower():
        model = Gemma3nForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
    else:
        model = AutoModelForImageTextToText.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )

    def collate_fn(examples):
        texts = []
        images_list = []
        for example in examples:
            # Apply chat template to get text
            text = processor.apply_chat_template(
                example["messages"], tokenize=False, add_generation_prompt=False
            ).strip()
            texts.append(text)
            # Extract images
            if "images" in example:  # single-image case
                images = [img.convert("RGB") for img in example["images"]]
            else:  # multi-image case or intersection dataset
                images = process_vision_info(example["messages"])
            images_list.append(images)
        # Tokenize the texts and process the images
        batch = processor(
            text=texts, images=images_list, return_tensors="pt", padding=True
        )
        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        # Mask tokens for Gemma3n model
        if "gemma-3n" in model_args.model_name_or_path.lower():
            # Use Gemma3n specific token masking
            labels[labels == processor.tokenizer.pad_token_id] = -100
            if hasattr(processor.tokenizer, "image_token_id"):
                labels[labels == processor.tokenizer.image_token_id] = -100
            if hasattr(processor.tokenizer, "boi_token_id"):
                labels[labels == processor.tokenizer.boi_token_id] = -100
            if hasattr(processor.tokenizer, "eoi_token_id"):
                labels[labels == processor.tokenizer.eoi_token_id] = -100
        else:
            # Original masking for other models.
            # Use a scalar id so the elementwise comparison below is well defined.
            image_token_id = processor.tokenizer.convert_tokens_to_ids(
                processor.tokenizer.special_tokens_map["boi_token"]
            )
            labels[labels == processor.tokenizer.pad_token_id] = -100
            labels[labels == image_token_id] = -100
            # 262144 is a hard-coded, model-specific id (the Gemma 3 image soft token)
            labels[labels == 262144] = -100
        batch["labels"] = labels
        return batch

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    # Handle different dataset formats
    if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
        dataset = prepare_dataset(
            dataset, script_args.dataset_name, script_args.dataset_train_split
        )
    elif script_args.dataset_name == "ariG23498/intersection-dataset":
        # Format intersection dataset
        dataset = dataset.map(
            format_intersection_data, batched=True, batch_size=4, num_proc=4
        )

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split]
        if training_args.eval_strategy != "no"
        else None,
        processing_class=processor.tokenizer,
        peft_config=my_get_peft_config(model_args),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)


if __name__ == "__main__":
    main()
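
The `--lora_target_modules` handling in `my_get_peft_config` comes down to a small normalization step that can be tried without PEFT installed. The sketch below is an approximation under stated assumptions: `normalize_target_modules` is a hypothetical helper name, and unlike the script it also drops empty entries such as those left by a trailing comma.

```python
def normalize_target_modules(value):
    # Keep None and the special "all-linear" keyword untouched.
    if value is None or value == "all-linear":
        return value
    # Split a comma-separated string into a clean list of module names.
    if isinstance(value, str):
        return [name.strip() for name in value.split(",") if name.strip()]
    # Anything else (for example a tuple) is passed through as a list.
    return list(value)


print(normalize_target_modules("all-linear"))
print(normalize_target_modules("q_proj, v_proj,"))
```

Keeping `"all-linear"` as a plain string matters because it is a keyword in its own right, not a module name to be wrapped in a list.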
@@ -0,0 +1,186 @@
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import random
from functools import partial
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
def collate_fn(examples, processor):
    messages = list()
    for sample in examples:
        image = sample["image"].convert("RGB")
        label = str(sample["label"])
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant with great geometry skills.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {
                        "type": "text",
                        "text": "How many intersection points are there in the image?",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
        messages.append(message)
    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Mask the tokens that we do not want to include in the loss computation
    # -100 is ignored during categorical cross entropy loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100
    batch["labels"] = labels
    return batch
def freeze_layers(model):
    for name, param in model.named_parameters():
        if "attn" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
def run_inference(val_dataset, processor, model, fname):
    # Predict on one random validation sample (called before and during training)
    val_sample = random.choice(val_dataset)
    image = val_sample["image"].convert("RGB")
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant with great geometry skills.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": "How many intersection points are there in the image?",
                },
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
        generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Pred: {decoded}")
    # Save before show(): after show() the current figure may already be empty.
    os.makedirs("outputs_fine_tune", exist_ok=True)
    plt.savefig(f"outputs_fine_tune/{fname}")
    plt.show()
def main():
model_id = "google/gemma-3n-E2B-it"
processor = Gemma3nProcessor.from_pretrained(model_id)
# load the dataset
dataset_id = "ariG23498/intersection-dataset"
train_dataset = load_dataset(dataset_id, split="train")
val_dataset = load_dataset(dataset_id, split="validation")
# create data loader
partial_collate_fn = partial(collate_fn, processor=processor)
train_dataloader = DataLoader(
train_dataset,
batch_size=2,
shuffle=True,
num_workers=8,
drop_last=True,
collate_fn=partial_collate_fn,
pin_memory=True,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=2,
shuffle=False,
num_workers=8,
drop_last=True,
collate_fn=partial_collate_fn,
)
# load the model and optimizer
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")
run_inference(val_dataset, processor, model, "pred_before.png")
model = freeze_layers(model)
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
    # Start training
    accumulation_steps = 8
    for idx, batch in tqdm(enumerate(train_dataloader)):
        outputs = model(**batch.to(model.device))
        # Scale the loss so gradients average over the accumulated micro-batches
        loss = outputs.loss / accumulation_steps
        if idx % 50 == 0:
            # Periodically report the validation loss and save a sample prediction
            val_loss = 0.0
            with torch.no_grad():
                count = 0
                for val_batch in val_dataloader:
                    val_loss = val_loss + model(**val_batch.to(model.device)).loss
                    count = count + 1
                val_loss = val_loss / count
            print(
                f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")
        loss.backward()
        # Step only after a full window of `accumulation_steps` micro-batches
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
if __name__ == "__main__":
main()
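# A minimal sketch, not part of the original script: it lists the batch indices
# at which an optimizer step fires under two gradient-accumulation conditions.
# The helper names (`step_indices`, `at_multiples`, `after_full_window`) are
# hypothetical and exist only for this illustration.

```python
def step_indices(num_batches, accumulation_steps, condition):
    """Return the batch indices at which `condition` triggers an optimizer step."""
    return [idx for idx in range(num_batches) if condition(idx, accumulation_steps)]


def at_multiples(idx, steps):
    # Fires at idx 0, i.e. after a single micro-batch has been accumulated.
    return idx % steps == 0


def after_full_window(idx, steps):
    # Fires only once `steps` micro-batches have been accumulated.
    return (idx + 1) % steps == 0


print(step_indices(24, 8, at_multiples))       # [0, 8, 16]
print(step_indices(24, 8, after_full_window))  # [7, 15, 23]
```

# With the first condition the very first update uses one micro-batch whose loss
# was already divided by the accumulation size, so that update is much smaller.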
@@ -0,0 +1,425 @@
# -*- coding: utf-8 -*-
"""Gemma3n Fine-tuning on All Modalities.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1iEZUJuvKJpGU8t50BqfkiCQmGkaR6gd4
# Fine-tune Gemma3n on FineVideo
In this notebook, we will see how to fine-tune Gemma3n on videos that include audio.
Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial to fit the model in 40GB VRAM.
"""
!pip install -U -q timm transformers trl peft datasets
import io
import os
import zipfile
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from trl import (
SFTConfig,
SFTTrainer,
)
"""## Download videos and preprocessing
FineVideo is quite a large dataset and we don't need a ton of examples, so we stream the dataset, check each video's duration, and download the videos shorter than 30 seconds.
"""
from datasets import load_dataset
import json
import os
dataset = load_dataset("HuggingFaceFV/finevideo", split="train", streaming=True)
os.makedirs("videos", exist_ok=True)
os.makedirs("metadata", exist_ok=True)
for idx, sample in enumerate(dataset):
data = sample["json"]
duration = data.get("duration_seconds", 0)
if duration < 30:
video_filename = f"videos/sample_{idx}.mp4"
with open(video_filename, 'wb') as video_file:
video_file.write(sample['mp4'])
json_filename = f"metadata/sample_{idx}.json"
with open(json_filename, 'w') as json_file:
json.dump(sample['json'], json_file)
print(f"Number of items in content/videos: {len(os.listdir('videos'))}")
"""In FineVideo some frames are dark, so we sample 6 frames per video and remove the videos for which we cannot find enough non-dark frames."""
import cv2
from PIL import Image
import numpy as np
def is_dark(frame, threshold=10):
return np.max(frame) < threshold # all pixels are very close to 0
def downsample_video(video_path):
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    # Generate 8 evenly spaced indices and skip the first and last, leaving 6
    full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1]
    for i in full_indices:
        found_valid = False
        for offset in [0, -1, 1, -2, 2]:  # Try nearby frames if original is dark
            candidate_idx = i + offset
            if 0 <= candidate_idx < total_frames:
                vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx)
                success, image = vidcap.read()
                if success:
                    if not is_dark(image):
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        pil_image = Image.fromarray(image)
                        timestamp = round(candidate_idx / fps, 2)
                        frames.append((pil_image, timestamp))
                        found_valid = True
                        break
        if not found_valid:
            print(f"Warning: Could not find non-dark frame near index {i}")
    # If fewer than 6 frames were found, try to top off by scanning from the start
    if len(frames) < 6:
        print("Trying to top off with additional non-dark frames...")
        idx = 0
        while len(frames) < 8 and idx < total_frames:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            success, image = vidcap.read()
            if success and not is_dark(image):
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                timestamp = round(idx / fps, 2)
                # Avoid adding duplicate timestamps
                if not any(ts == timestamp for _, ts in frames):
                    frames.append((pil_image, timestamp))
            idx += 1
    # Release the capture only after the top-off pass, which still reads from it
    vidcap.release()
    return frames[:8]  # Return at most 8 frames (6 when the first pass succeeds)
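# A small sketch of the index selection used above, written in pure Python so it
# runs without NumPy or OpenCV. `sample_frame_indices` is a hypothetical helper,
# and the integer division only approximates np.linspace(..., dtype=int); the two
# may differ by one frame in floating-point edge cases.

```python
def sample_frame_indices(total_frames, num_points=8):
    """Evenly spaced frame indices with the first and last points dropped."""
    last = total_frames - 1
    points = [i * last // (num_points - 1) for i in range(num_points)]
    # Dropping both end points avoids the opening and closing frames of a clip.
    return points[1:-1]


print(sample_frame_indices(100))  # [14, 28, 42, 56, 70, 84]
```

# Eight points minus the two end points gives the six frames the notebook text
# refers to.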
import os
import glob
def remove_dark_videos(video_dir, metadata_dir, audio_dir):
"""
Remove videos (and their metadata/audio files) if all frames are dark.
"""
video_paths = glob.glob(os.path.join(video_dir, "*.mp4"))
for video_path in video_paths:
filename = os.path.basename(video_path)
base_name = os.path.splitext(filename)[0]
frames = downsample_video(video_path)
if len(frames) < 6:
try:
os.remove(video_path)
print(f"Deleted: {video_path}")
except Exception as e:
print(f"Failed to delete {video_path}: {e}")
metadata_path = os.path.join(metadata_dir, f"{base_name}.json")
if os.path.exists(metadata_path):
os.remove(metadata_path)
# Remove audio
audio_path = os.path.join(audio_dir, f"{base_name}.wav")
if os.path.exists(audio_path):
os.remove(audio_path)
remove_dark_videos(
video_dir="videos",
metadata_dir="metadata",
audio_dir="audios"
)
"""Gemma-3n accepts video (image frames) and audio separately, so we strip audio from video."""
import os
import subprocess
video_dir = "videos"
audio_dir = "audios"
os.makedirs(audio_dir, exist_ok=True)
for filename in os.listdir(video_dir):
if not filename.endswith(".mp4"):
continue
idx = filename.split("_")[1].split(".")[0]
video_path = os.path.join(video_dir, filename)
audio_path = os.path.join(audio_dir, f"sample_{idx}.wav")
subprocess.run([
"ffmpeg", "-i", video_path,
"-q:a", "0", "-map", "a",
audio_path,
"-y"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
"""Construct a new dataset with audio, video, and metadata (video categories). FineVideo is very rich: it has questions and answers, captions, and more, so get creative if you have the GPU VRAM to do so. Here we solve an easier task for educational purposes."""
from datasets import Dataset
import json
def gen():
meta_dir = "metadata"
for filename in os.listdir(meta_dir):
if not filename.endswith(".json"):
continue
idx = filename.split("_")[1].split(".")[0]
if os.path.exists(f"videos/sample_{idx}.mp4"):
video_filename = f"sample_{idx}.mp4"
audio_filename = f"sample_{idx}.wav"
json_path = os.path.join(meta_dir, filename)
with open(json_path, "r") as f:
metadata = json.load(f)
yield {
"video": video_filename,
"audio": audio_filename,
"content_parent_category": metadata["content_parent_category"],
"sample_index": int(idx)
}
else:
pass
dataset = Dataset.from_generator(gen)
"""We will speed up and downsample the audio to save space during training."""
import torchaudio
from torchaudio.transforms import Resample
import os
import torch
def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25):
waveform, sample_rate = torchaudio.load(audio_path)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != target_sample_rate:
resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
waveform = resampler(waveform)
sample_rate = target_sample_rate
if speedup_factor > 1.0:
indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long()
if indices[-1] >= waveform.shape[1]:
indices = indices[:-1]
waveform = waveform[:, indices]
max_length = int(target_sample_rate * max_duration_sec)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
torchaudio.save(audio_path, waveform, sample_rate)
for file_name in os.listdir("audios"):
if file_name.lower().endswith(".wav"):
audio_path = os.path.join("audios", file_name)
preprocess_audio(audio_path)
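# A rough sketch of the arithmetic behind the speed-up step above, in pure Python.
# `speedup_indices` and `covered_seconds` are hypothetical helpers; the first
# mirrors the index selection of torch.arange(0, n, step=factor).long() for small
# inputs, under the assumption that the factor is exactly representable (1.25 is).

```python
def speedup_indices(num_samples, speedup_factor=1.25):
    """Sample positions kept when every `speedup_factor`-th sample is taken."""
    indices = []
    k = 0
    while k * speedup_factor < num_samples:
        indices.append(int(k * speedup_factor))  # truncate, like .long()
        k += 1
    return indices


def covered_seconds(max_duration_sec=5, speedup_factor=1.25):
    """Seconds of original audio that fit into the truncated, sped-up clip."""
    return max_duration_sec * speedup_factor


print(speedup_indices(10))  # [0, 1, 2, 3, 5, 6, 7, 8]
print(covered_seconds())    # 6.25
```

# At 16 kHz a 5-second cap is 80,000 samples, which after a 1.25x speed-up holds
# 6.25 seconds of the original recording. Skipping samples this way also raises
# the pitch, which is acceptable for a category-prediction demo.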
dataset = dataset.train_test_split(test_size=0.10, seed=42)
"""### Load the model
Make sure you have your Hugging Face token in your Colab secrets.
"""
model = Gemma3nForConditionalGeneration.from_pretrained(
"google/gemma-3n-E2B-it", torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(
"google/gemma-3n-E2B-it",
)
processor.tokenizer.padding_side = "right"
processor.tokenizer.all_special_ids
"""Write our dataset collator. We will train the model to predict the category of a video, which is an easy task. You can do much more: for instance, FineVideo has a QnA section, so you can train the model on open-ended QnA if you have a lot of VRAM and patience. Open-ended tasks are harder to work with, and this notebook is meant to teach how to feed in different modalities.
In the collator we also downsample videos to 6 frames using the helper written above. For better results you need more frames.
"""
def collate_fn(examples):
video_path = examples[0]["video"]
audio_path = examples[0]["audio"]
# Derive the index from this example's own file name, not from a leftover global
sample_idx = video_path.split("_")[1].split(".")[0]
frames = downsample_video(f"videos/{video_path}")
text = "Based on the video, predict the category of it."
message = [
{
"role": "user",
"content": [
{"type": "text", "text": text}
],
},
]
# this is how video inference should be formatted in Gemma3n
for frame in frames:
image, timestamp = frame
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
timestamp = str(timestamp).replace(".", "_")
image.save(f"image_idx_{sample_idx}_{timestamp}.png")
message[0]["content"].append({"type": "image", "url": f"image_idx_{sample_idx}_{timestamp}.png"})
message[0]["content"].append({"type": "audio", "audio": f"audios/{audio_path}"})
message.append({"role": "assistant", "content": [{"type": "text", "text": examples[0]["content_parent_category"]}]})
inputs = processor.apply_chat_template(
message,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
).to(model.device)
labels = inputs["input_ids"].clone()
special_token_ids = processor.tokenizer.all_special_ids
special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device)
mask = torch.isin(labels, special_token_ids_tensor)
labels[mask] = -100
inputs["labels"] = labels
if torch.all(inputs["pixel_values"] == 0):
print("Frames are dark")
return inputs
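# A minimal sketch of the label masking done in the collator above, using plain
# lists so it runs without torch. The token ids and `mask_special_tokens` are
# hypothetical; -100 is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100


def mask_special_tokens(input_ids, special_ids):
    """Replace every special-token id with IGNORE_INDEX so the loss skips it."""
    special = set(special_ids)
    return [IGNORE_INDEX if token in special else token for token in input_ids]


# Hypothetical ids: 0 = padding, 2 = beginning of sequence, 105/107 = turn markers.
input_ids = [2, 105, 9731, 107, 0, 0]
special_ids = [0, 2, 105, 107]
print(mask_special_tokens(input_ids, special_ids))
# [-100, -100, 9731, -100, -100, -100]
```

# Only special tokens are masked, so ordinary prompt tokens still contribute to
# the loss alongside the assistant answer.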
"""## Training
We use LoRA fine-tuning again to save memory.
"""
from peft import LoraConfig
peft_config = LoraConfig(
task_type="CAUSAL_LM",
r=16,
target_modules="all-linear",
lora_alpha=32,
lora_dropout=0.05,
bias="none",
use_rslora=False,
use_dora=False,
modules_to_save=None
)
model.gradient_checkpointing_disable()
model.config.use_cache = False
training_args = SFTConfig(
output_dir="/content/gemma-3n-finevideo",
eval_strategy='epoch',
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
gradient_accumulation_steps=4,
gradient_checkpointing=False,
learning_rate=1e-05,
num_train_epochs=3.0,
logging_steps=10,
save_steps=100,
bf16=True,
report_to=["tensorboard"],
dataset_kwargs={'skip_prepare_dataset': True},
remove_unused_columns=False,
max_seq_length=None,
push_to_hub=True,
dataloader_pin_memory=False,
)
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=dataset["train"],
eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None,
processing_class=processor.tokenizer,
peft_config=peft_config,
)
trainer.train()
"""Test the model with a video of snowboarding."""
!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4
model = trainer.model # trainer has the adapter
"""Strip audio and downsample video."""
audio_path = "/content/test_audio.wav"
subprocess.run([
"ffmpeg", "-i", "/content/IMG_8137.mp4",
"-q:a", "0", "-map", "a",
f"{audio_path}",
"-y"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
frames = downsample_video("/content/IMG_8137.mp4")
# repeat the chat template
text = "Based on the video, predict the category of it."
message = [
{
"role": "user",
"content": [
{"type": "text", "text": text}
],
},
]
for frame in frames:
image, timestamp = frame
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
timestamp = str(timestamp).replace(".", "_")
image.save(f"test_frame_{timestamp}.png")
message[0]["content"].append({"type": "image", "url": f"test_frame_{timestamp}.png"})
message[0]["content"].append({"type": "audio", "audio": f"{audio_path}"})
message
inputs = processor.apply_chat_template(
message,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
).to(model.device).to(model.dtype)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
"""Thanks a lot for reading! Keep training the model further with more data or unfreeze the layers for better performance 💗"""
@@ -0,0 +1,13 @@
---
title: Gemma 4 31B It
emoji: 🚀
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 6.12.0
python_version: "3.12.12"
app_file: app.py
pinned: false
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
@@ -0,0 +1,303 @@
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
from transformers.generation.streamers import TextIteratorStreamer
MODEL_ID = "google/gemma-4-31b-it"
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
model = AutoModelForMultimodalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16)
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".avi", ".webm")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
THINKING_START = "<|channel>"
THINKING_END = "<channel|>"
# Special tokens to strip from decoded output (keeping thinking delimiters
# so that Gradio's reasoning_tags can find them on the frontend).
_KEEP_TOKENS = {THINKING_START, THINKING_END}
_STRIP_TOKENS = sorted(
(t for t in processor.tokenizer.all_special_tokens if t not in _KEEP_TOKENS),
key=len,
reverse=True, # longest first to avoid partial matches
)
def _strip_special_tokens(text: str) -> str:
for tok in _STRIP_TOKENS:
text = text.replace(tok, "")
return text
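# A short sketch of why the strip list above is sorted longest first. The tokens
# "<tok>" and "<tok>_end" are made up for illustration and are not real Gemma
# special tokens; `strip_tokens` is a hypothetical stand-in for the helper above.

```python
def strip_tokens(text, tokens, longest_first=True):
    """Remove each token from `text`, in length order."""
    for tok in sorted(tokens, key=len, reverse=longest_first):
        text = text.replace(tok, "")
    return text


tokens = ["<tok>", "<tok>_end"]
text = "hi<tok>_end!"

print(strip_tokens(text, tokens, longest_first=True))   # hi!
print(strip_tokens(text, tokens, longest_first=False))  # hi_end!
```

# When one token is a prefix of another, removing the shorter one first leaves a
# fragment of the longer one behind.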
def _classify_file(path: str) -> str | None:
"""Return media type string for a file path, or None if unsupported."""
lower = path.lower()
if lower.endswith(IMAGE_FILE_TYPES):
return "image"
if lower.endswith(VIDEO_FILE_TYPES):
return "video"
return None
def process_new_user_message(message: dict) -> list[dict]:
"""Build content list from the new user message with URL-based media references."""
content: list[dict] = []
for path in message.get("files", []):
kind = _classify_file(path)
if kind:
content.append({"type": kind, "url": path})
content.append({"type": "text", "text": message.get("text", "")})
return content
def process_history(history: list[dict]) -> list[dict]:
"""Walk Gradio 6 history and build message list with URL-based media references."""
messages: list[dict] = []
for item in history:
if item["role"] == "assistant":
text_parts = [p["text"] for p in item["content"] if p.get("type") == "text"]
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": " ".join(text_parts)}],
}
)
else:
user_content: list[dict] = []
for part in item["content"]:
if part.get("type") == "text":
user_content.append({"type": "text", "text": part["text"]})
elif part.get("type") == "file":
filepath = part["file"]["path"]
kind = _classify_file(filepath)
if kind:
user_content.append({"type": kind, "url": filepath})
if user_content:
messages.append({"role": "user", "content": user_content})
return messages
@spaces.GPU(duration=180)
@torch.inference_mode()
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
streamer = TextIteratorStreamer(
processor,
timeout=30.0,
skip_prompt=True,
skip_special_tokens=not thinking,
)
generate_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"disable_compile": True,
}
exception_holder: list[Exception] = []
def _generate() -> None:
try:
model.generate(**generate_kwargs)
except Exception as e: # noqa: BLE001
exception_holder.append(e)
thread = Thread(target=_generate)
thread.start()
chunks: list[str] = []
for text in streamer:
chunks.append(text)
accumulated = "".join(chunks)
if thinking:
yield _strip_special_tokens(accumulated)
else:
yield accumulated
thread.join()
if exception_holder:
msg = f"Generation failed: {exception_holder[0]}"
raise gr.Error(msg)
def validate_input(message: dict) -> dict:
has_text = bool(message.get("text", "").strip())
has_files = bool(message.get("files"))
if not (has_text or has_files):
return gr.validate(False, "Please enter a message or upload a file.")
files = message.get("files", [])
kinds = [_classify_file(f) for f in files]
kinds = [k for k in kinds if k is not None]
unique_kinds = set(kinds)
if len(unique_kinds) > 1:
return gr.validate(False, "Please upload only one type of media (images or video) at a time.")
if kinds.count("video") > 1:
return gr.validate(False, "Only one video file can be uploaded at a time.")
return gr.validate(True, "")
def _has_media_type(messages: list[dict], media_type: str) -> bool:
"""Check if any message contains a content entry of the given media type."""
return any(
c.get("type") == media_type for m in messages for c in (m["content"] if isinstance(m["content"], list) else [])
)
def generate(
message: dict,
history: list[dict],
thinking: bool = False,
max_new_tokens: int = 1024,
max_soft_tokens: int = 280,
system_prompt: str = "",
) -> Iterator[str]:
messages: list[dict] = []
if system_prompt:
messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
messages.extend(process_history(history))
messages.append({"role": "user", "content": process_new_user_message(message)})
template_kwargs: dict = {
"tokenize": True,
"return_dict": True,
"return_tensors": "pt",
"add_generation_prompt": True,
"processor_kwargs": {"images_kwargs": {"max_soft_tokens": max_soft_tokens}},
}
if _has_media_type(messages, "video"):
template_kwargs["load_audio_from_video"] = False
if thinking:
template_kwargs["enable_thinking"] = True
inputs = processor.apply_chat_template(messages, **template_kwargs)
n_tokens = inputs["input_ids"].shape[1]
if n_tokens > MAX_INPUT_TOKENS:
msg = f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
raise gr.Error(msg)
yield from _generate_on_gpu(inputs=inputs, max_new_tokens=max_new_tokens, thinking=thinking)
examples = [
# --- Text-only examples ---
[
{
"text": "What is the capital of France?",
"files": [],
}
],
[
{
"text": "What is the water formula?",
"files": [],
}
],
[
{
"text": "Explain quantum entanglement in simple terms.",
"files": [],
}
],
[
{
"text": "I want to do a car wash that is 50 meters away, should I walk or drive?",
"files": [],
}
],
[
{
"text": "Write a poem about beer with 4 stanzas. Format the title as an H2 markdown heading and bold the first line of each stanza.",
"files": [],
}
],
# --- Single-image examples ---
[
{
"text": "Describe this image.",
"files": ["https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg"],
}
],
[
{
"text": "What is the city in this image? Describe what you see.",
"files": ["https://imgmd.net/images/v1/guia/1698673/rio-de-janeiro-4-c.jpg"],
}
],
# --- Multi-image examples ---
[
{
"text": "What are the key similarities between these three images?",
"files": [
"https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg",
"https://ogimg.infoglobo.com.br/in/12547538-502-0e0/FT1086A/94-8705-14.jpg",
"https://amazonasatual.com.br/wp-content/uploads/2021/01/Pele.jpg",
],
}
],
# --- Video examples ---
[
{
"text": "What is happening in this video?",
"files": ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"],
}
],
]
demo = gr.ChatInterface(
fn=generate,
validator=validate_input,
chatbot=gr.Chatbot(
scale=1,
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
],
reasoning_tags=[(THINKING_START, THINKING_END)],
),
textbox=gr.MultimodalTextbox(
sources=["upload"],
file_types=[*IMAGE_FILE_TYPES, *VIDEO_FILE_TYPES],
file_count="multiple",
autofocus=True,
),
multimodal=True,
additional_inputs=[
gr.Checkbox(label="Thinking", value=False),
gr.Slider(label="Max New Tokens", minimum=100, maximum=4000, step=10, value=2000),
gr.Dropdown(
label="Image Token Budget",
info="Higher values preserve more visual detail (useful for OCR/documents). Lower values are faster.",
choices=[70, 140, 280, 560, 1120],
value=280,
),
gr.Textbox(label="System Prompt", value=""),
],
additional_inputs_accordion=gr.Accordion("Settings", open=True),
stop_btn=False,
title="Gemma 4 31B It",
examples=examples,
run_examples_on_click=False,
cache_examples=False,
delete_cache=(1800, 1800),
)
if __name__ == "__main__":
demo.launch(css_paths="style.css", max_file_size="20mb")
@@ -0,0 +1,362 @@
# This file was autogenerated by uv via the following command:
# uv export --no-hashes --no-dev --group hf-spaces --no-emit-package typer-slim --no-emit-package spaces -o requirements.txt
accelerate==1.13.0
# via gemma-4-31b-it
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.5
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-doc==0.0.4
# via
# fastapi
# typer
annotated-types==0.7.0
# via pydantic
anyio==4.13.0
# via
# gradio
# httpx
# mcp
# sse-starlette
# starlette
attrs==26.1.0
# via
# aiohttp
# jsonschema
# referencing
audioop-lts==0.2.2 ; python_full_version >= '3.13'
# via gradio
brotli==1.2.0
# via gradio
certifi==2026.2.25
# via
# httpcore
# httpx
# requests
cffi==2.0.0 ; platform_python_implementation != 'PyPy'
# via cryptography
charset-normalizer==3.4.7
# via requests
click==8.3.2
# via
# typer
# uvicorn
colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.7
# via pyjwt
datasets==4.8.4
dill==0.4.1
# via
# datasets
# multiprocess
fastapi==0.136.0
# via gradio
filelock==3.28.0
# via
# datasets
# huggingface-hub
# torch
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec==2026.2.0
# via
# datasets
# gradio-client
# huggingface-hub
# torch
gradio==6.12.0
# via
# gemma-4-31b-it
# spaces
gradio-client==2.4.1
# via
# gradio
# hf-gradio
groovy==0.1.2
# via gradio
h11==0.16.0
# via
# httpcore
# uvicorn
hf-gradio==0.4.0
# via gradio
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via
# datasets
# gradio
# gradio-client
# huggingface-hub
# mcp
# safehttpx
# spaces
httpx-sse==0.4.3
# via mcp
huggingface-hub==1.11.0
# via
# accelerate
# datasets
# gradio
# gradio-client
# tokenizers
# transformers
idna==3.11
# via
# anyio
# httpx
# requests
# yarl
jinja2==3.1.6
# via
# gradio
# torch
jsonschema==4.26.0
# via mcp
jsonschema-specifications==2025.9.1
# via jsonschema
markdown-it-py==4.0.0
# via rich
markupsafe==3.0.3
# via
# gradio
# jinja2
mcp==1.27.0
# via gradio
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
multidict==6.7.1
# via
# aiohttp
# yarl
multiprocess==0.70.19
# via datasets
networkx==3.6.1
# via torch
numpy==2.4.4
# via
# accelerate
# datasets
# gradio
# pandas
# torchvision
# transformers
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
orjson==3.11.8
# via gradio
packaging==26.1
# via
# accelerate
# datasets
# gradio
# gradio-client
# huggingface-hub
# spaces
# transformers
pandas==3.0.2
# via
# datasets
# gradio
pillow==12.2.0
# via
# gradio
# torchvision
propcache==0.4.1
# via
# aiohttp
# yarl
psutil==5.9.8
# via
# accelerate
# spaces
pyarrow==23.0.1
# via datasets
pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
# via cffi
pydantic==2.12.5
# via
# fastapi
# gradio
# mcp
# pydantic-settings
# spaces
pydantic-core==2.41.5
# via pydantic
pydantic-settings==2.13.1
# via mcp
pydub==0.25.1
# via gradio
pygments==2.20.0
# via rich
pyjwt==2.12.1
# via mcp
python-dateutil==2.9.0.post0
# via pandas
python-dotenv==1.2.2
# via pydantic-settings
python-multipart==0.0.26
# via
# gradio
# mcp
pytz==2026.1.post1
# via gradio
pywin32==311 ; sys_platform == 'win32'
# via mcp
pyyaml==6.0.3
# via
# accelerate
# datasets
# gradio
# huggingface-hub
# transformers
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2026.4.4
# via transformers
requests==2.33.1
# via
# datasets
# spaces
rich==15.0.0
# via typer
rpds-py==0.30.0
# via
# jsonschema
# referencing
safehttpx==0.1.7
# via gradio
safetensors==0.7.0
# via
# accelerate
# transformers
semantic-version==2.10.0
# via gradio
setuptools==82.0.1
# via torch
shellingham==1.5.4
# via typer
six==1.17.0
# via python-dateutil
sse-starlette==3.3.4
# via mcp
starlette==1.0.0
# via
# fastapi
# gradio
# mcp
# sse-starlette
sympy==1.14.0
# via torch
tokenizers==0.22.2
# via transformers
tomlkit==0.14.0
# via gradio
torch==2.9.1
# via
# accelerate
# gemma-4-31b-it
# torchvision
torchcodec==0.9.1
# via gemma-4-31b-it
torchvision==0.24.1
# via gemma-4-31b-it
tqdm==4.67.3
# via
# datasets
# huggingface-hub
# transformers
transformers==5.5.4
# via gemma-4-31b-it
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
typer==0.24.1
# via
# gradio
# hf-gradio
# huggingface-hub
# transformers
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# fastapi
# gradio
# gradio-client
# huggingface-hub
# mcp
# pydantic
# pydantic-core
# referencing
# spaces
# starlette
# torch
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings
tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
# via pandas
urllib3==2.6.3
# via requests
uvicorn==0.44.0
# via
# gradio
# mcp
xxhash==3.6.0
# via datasets
yarl==1.23.0
# via aiohttp
@@ -0,0 +1,13 @@
---
title: Gemma 4 E4B It
emoji: 🚀
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 6.12.0
python_version: "3.12.12"
app_file: app.py
pinned: false
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
@@ -0,0 +1,322 @@
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
from transformers.generation.streamers import TextIteratorStreamer
MODEL_ID = "google/gemma-4-e4b-it"
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
model = AutoModelForMultimodalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16)
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
AUDIO_FILE_TYPES = (".wav", ".mp3", ".flac", ".ogg")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".avi", ".webm")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
THINKING_START = "<|channel>"
THINKING_END = "<channel|>"
# Special tokens to strip from decoded output (keeping thinking delimiters
# so that Gradio's reasoning_tags can find them on the frontend).
_KEEP_TOKENS = {THINKING_START, THINKING_END}
_STRIP_TOKENS = sorted(
(t for t in processor.tokenizer.all_special_tokens if t not in _KEEP_TOKENS),
key=len,
reverse=True, # longest first to avoid partial matches
)
def _strip_special_tokens(text: str) -> str:
for tok in _STRIP_TOKENS:
text = text.replace(tok, "")
return text
def _classify_file(path: str) -> str | None:
"""Return media type string for a file path, or None if unsupported."""
lower = path.lower()
if lower.endswith(IMAGE_FILE_TYPES):
return "image"
if lower.endswith(AUDIO_FILE_TYPES):
return "audio"
if lower.endswith(VIDEO_FILE_TYPES):
return "video"
return None
def process_new_user_message(message: dict) -> list[dict]:
"""Build content list from the new user message with URL-based media references."""
content: list[dict] = []
for path in message.get("files", []):
kind = _classify_file(path)
if kind:
content.append({"type": kind, "url": path})
content.append({"type": "text", "text": message.get("text", "")})
return content
def process_history(history: list[dict]) -> list[dict]:
"""Walk Gradio 6 history and build message list with URL-based media references."""
messages: list[dict] = []
for item in history:
if item["role"] == "assistant":
text_parts = [p["text"] for p in item["content"] if p.get("type") == "text"]
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": " ".join(text_parts)}],
}
)
else:
user_content: list[dict] = []
for part in item["content"]:
if part.get("type") == "text":
user_content.append({"type": "text", "text": part["text"]})
elif part.get("type") == "file":
filepath = part["file"]["path"]
kind = _classify_file(filepath)
if kind:
user_content.append({"type": kind, "url": filepath})
if user_content:
messages.append({"role": "user", "content": user_content})
return messages
@spaces.GPU(duration=120)
@torch.inference_mode()
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
streamer = TextIteratorStreamer(
processor,
timeout=30.0,
skip_prompt=True,
skip_special_tokens=not thinking,
)
generate_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"disable_compile": True,
}
exception_holder: list[Exception] = []
def _generate() -> None:
try:
model.generate(**generate_kwargs)
except Exception as e: # noqa: BLE001
exception_holder.append(e)
thread = Thread(target=_generate)
thread.start()
chunks: list[str] = []
for text in streamer:
chunks.append(text)
accumulated = "".join(chunks)
if thinking:
yield _strip_special_tokens(accumulated)
else:
yield accumulated
thread.join()
if exception_holder:
msg = f"Generation failed: {exception_holder[0]}"
raise gr.Error(msg)
# FBT003 is suppressed below: gr.validate API takes bool as first positional arg.
def validate_input(message: dict) -> dict:
has_text = bool(message.get("text", "").strip())
has_files = bool(message.get("files"))
if not (has_text or has_files):
return gr.validate(False, "Please enter a message or upload a file.") # noqa: FBT003
files = message.get("files", [])
kinds = [_classify_file(f) for f in files]
kinds = [k for k in kinds if k is not None]
unique_kinds = set(kinds)
if len(unique_kinds) > 1:
return gr.validate(False, "Please upload only one type of media (images, audio, or video) at a time.") # noqa: FBT003
if kinds.count("audio") > 1:
return gr.validate(False, "Only one audio file can be uploaded at a time.") # noqa: FBT003
if kinds.count("video") > 1:
return gr.validate(False, "Only one video file can be uploaded at a time.") # noqa: FBT003
return gr.validate(True, "") # noqa: FBT003
def _has_media_type(messages: list[dict], media_type: str) -> bool:
"""Check if any message contains a content entry of the given media type."""
return any(c.get("type") == media_type for m in messages for c in m["content"])
def generate(
message: dict,
history: list[dict],
thinking: bool = False,
max_new_tokens: int = 1024,
max_soft_tokens: int = 280,
system_prompt: str = "",
) -> Iterator[str]:
messages: list[dict] = []
if system_prompt:
messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
messages.extend(process_history(history))
messages.append({"role": "user", "content": process_new_user_message(message)})
template_kwargs: dict = {
"tokenize": True,
"return_dict": True,
"return_tensors": "pt",
"add_generation_prompt": True,
"load_audio_from_video": _has_media_type(messages, "video"),
"processor_kwargs": {"images_kwargs": {"max_soft_tokens": max_soft_tokens}},
}
if thinking:
template_kwargs["enable_thinking"] = True
inputs = processor.apply_chat_template(messages, **template_kwargs)
n_tokens = inputs["input_ids"].shape[1]
if n_tokens > MAX_INPUT_TOKENS:
msg = f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
raise gr.Error(msg)
yield from _generate_on_gpu(inputs=inputs, max_new_tokens=max_new_tokens, thinking=thinking)
examples = [
# --- Text-only examples ---
[
{
"text": "What is the capital of France?",
"files": [],
}
],
[
{
"text": "What is the water formula?",
"files": [],
}
],
[
{
"text": "Explain quantum entanglement in simple terms.",
"files": [],
}
],
[
{
"text": "I want to do a car wash that is 50 meters away, should I walk or drive?",
"files": [],
}
],
[
{
"text": "Write a poem about beer with 4 stanzas. Format the title as an H2 markdown heading and bold the first line of each stanza.",
"files": [],
}
],
# --- Single-image examples ---
[
{
"text": "Describe this image.",
"files": ["https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg"],
}
],
[
{
"text": "What is the city in this image? Describe what you see.",
"files": ["https://imgmd.net/images/v1/guia/1698673/rio-de-janeiro-4-c.jpg"],
}
],
# --- Multi-image examples ---
[
{
"text": "What are the key similarities between these three images?",
"files": [
"https://news.bbc.co.uk/media/images/38107000/jpg/_38107299_ronaldogoal_ap_300.jpg",
"https://ogimg.infoglobo.com.br/in/12547538-502-0e0/FT1086A/94-8705-14.jpg",
"https://amazonasatual.com.br/wp-content/uploads/2021/01/Pele.jpg",
],
}
],
# --- Audio examples ---
[
{
"text": "Transcribe the audio.",
"files": [
"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3"
],
}
],
[
{
"text": "Translate to Dutch.",
"files": [
"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3"
],
}
],
# --- Video examples ---
[
{
"text": "What is happening in this video?",
"files": ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"],
}
],
]
demo = gr.ChatInterface(
fn=generate,
validator=validate_input,
chatbot=gr.Chatbot(
scale=1,
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
],
reasoning_tags=[(THINKING_START, THINKING_END)],
),
textbox=gr.MultimodalTextbox(
sources=["upload", "microphone"],
file_types=[*IMAGE_FILE_TYPES, *AUDIO_FILE_TYPES, *VIDEO_FILE_TYPES],
file_count="multiple",
autofocus=True,
),
multimodal=True,
additional_inputs=[
gr.Checkbox(label="Thinking", value=False),
gr.Slider(label="Max New Tokens", minimum=100, maximum=4000, step=10, value=2000),
gr.Dropdown(
label="Image Token Budget",
info="Higher values preserve more visual detail (useful for OCR/documents). Lower values are faster.",
choices=[70, 140, 280, 560, 1120],
value=280,
),
gr.Textbox(label="System Prompt", value=""),
],
additional_inputs_accordion=gr.Accordion("Settings", open=True),
stop_btn=False,
title="Gemma 4 E4B It",
examples=examples,
run_examples_on_click=False,
cache_examples=False,
delete_cache=(1800, 1800),
)
if __name__ == "__main__":
demo.launch(css_paths="style.css", max_file_size="20MB")
@@ -0,0 +1,362 @@
# This file was autogenerated by uv via the following command:
# uv export --no-hashes --no-dev --group hf-spaces --no-emit-package typer-slim --no-emit-package spaces -o requirements.txt
accelerate==1.13.0
# via gemma-4-e4b-it
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.5
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-doc==0.0.4
# via
# fastapi
# typer
annotated-types==0.7.0
# via pydantic
anyio==4.13.0
# via
# gradio
# httpx
# mcp
# sse-starlette
# starlette
attrs==26.1.0
# via
# aiohttp
# jsonschema
# referencing
audioop-lts==0.2.2 ; python_full_version >= '3.13'
# via gradio
brotli==1.2.0
# via gradio
certifi==2026.2.25
# via
# httpcore
# httpx
# requests
cffi==2.0.0 ; platform_python_implementation != 'PyPy'
# via cryptography
charset-normalizer==3.4.7
# via requests
click==8.3.2
# via
# typer
# uvicorn
colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.7
# via pyjwt
datasets==4.8.4
dill==0.4.1
# via
# datasets
# multiprocess
fastapi==0.136.0
# via gradio
filelock==3.28.0
# via
# datasets
# huggingface-hub
# torch
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec==2026.2.0
# via
# datasets
# gradio-client
# huggingface-hub
# torch
gradio==6.12.0
# via
# gemma-4-e4b-it
# spaces
gradio-client==2.4.1
# via
# gradio
# hf-gradio
groovy==0.1.2
# via gradio
h11==0.16.0
# via
# httpcore
# uvicorn
hf-gradio==0.4.0
# via gradio
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via
# datasets
# gradio
# gradio-client
# huggingface-hub
# mcp
# safehttpx
# spaces
httpx-sse==0.4.3
# via mcp
huggingface-hub==1.11.0
# via
# accelerate
# datasets
# gradio
# gradio-client
# tokenizers
# transformers
idna==3.11
# via
# anyio
# httpx
# requests
# yarl
jinja2==3.1.6
# via
# gradio
# torch
jsonschema==4.26.0
# via mcp
jsonschema-specifications==2025.9.1
# via jsonschema
markdown-it-py==4.0.0
# via rich
markupsafe==3.0.3
# via
# gradio
# jinja2
mcp==1.27.0
# via gradio
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
multidict==6.7.1
# via
# aiohttp
# yarl
multiprocess==0.70.19
# via datasets
networkx==3.6.1
# via torch
numpy==2.4.4
# via
# accelerate
# datasets
# gradio
# pandas
# torchvision
# transformers
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
orjson==3.11.8
# via gradio
packaging==26.1
# via
# accelerate
# datasets
# gradio
# gradio-client
# huggingface-hub
# spaces
# transformers
pandas==3.0.2
# via
# datasets
# gradio
pillow==12.2.0
# via
# gradio
# torchvision
propcache==0.4.1
# via
# aiohttp
# yarl
psutil==5.9.8
# via
# accelerate
# spaces
pyarrow==23.0.1
# via datasets
pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
# via cffi
pydantic==2.12.5
# via
# fastapi
# gradio
# mcp
# pydantic-settings
# spaces
pydantic-core==2.41.5
# via pydantic
pydantic-settings==2.13.1
# via mcp
pydub==0.25.1
# via gradio
pygments==2.20.0
# via rich
pyjwt==2.12.1
# via mcp
python-dateutil==2.9.0.post0
# via pandas
python-dotenv==1.2.2
# via pydantic-settings
python-multipart==0.0.26
# via
# gradio
# mcp
pytz==2026.1.post1
# via gradio
pywin32==311 ; sys_platform == 'win32'
# via mcp
pyyaml==6.0.3
# via
# accelerate
# datasets
# gradio
# huggingface-hub
# transformers
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2026.4.4
# via transformers
requests==2.33.1
# via
# datasets
# spaces
rich==15.0.0
# via typer
rpds-py==0.30.0
# via
# jsonschema
# referencing
safehttpx==0.1.7
# via gradio
safetensors==0.7.0
# via
# accelerate
# transformers
semantic-version==2.10.0
# via gradio
setuptools==82.0.1
# via torch
shellingham==1.5.4
# via typer
six==1.17.0
# via python-dateutil
sse-starlette==3.3.4
# via mcp
starlette==1.0.0
# via
# fastapi
# gradio
# mcp
# sse-starlette
sympy==1.14.0
# via torch
tokenizers==0.22.2
# via transformers
tomlkit==0.14.0
# via gradio
torch==2.9.1
# via
# accelerate
# gemma-4-e4b-it
# torchvision
torchcodec==0.9.1
# via gemma-4-e4b-it
torchvision==0.24.1
# via gemma-4-e4b-it
tqdm==4.67.3
# via
# datasets
# huggingface-hub
# transformers
transformers==5.5.4
# via gemma-4-e4b-it
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
typer==0.24.1
# via
# gradio
# hf-gradio
# huggingface-hub
# transformers
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# fastapi
# gradio
# gradio-client
# huggingface-hub
# mcp
# pydantic
# pydantic-core
# referencing
# spaces
# starlette
# torch
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings
tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
# via pandas
urllib3==2.6.3
# via requests
uvicorn==0.44.0
# via
# gradio
# mcp
xxhash==3.6.0
# via datasets
yarl==1.23.0
# via aiohttp
@@ -0,0 +1,33 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_gemma4 import *
from .feature_extraction_gemma4 import *
from .image_processing_gemma4 import *
from .image_processing_pil_gemma4 import *
from .modeling_gemma4 import *
from .processing_gemma4 import *
from .video_processing_gemma4 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
@@ -0,0 +1,352 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from huggingface_hub.dataclasses import strict
from ...configuration_utils import PreTrainedConfig
from ...utils import auto_docstring, logging
from ...utils.type_validators import interval
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4AudioConfig(PreTrainedConfig):
r"""
subsampling_conv_channels (`list[int]`, defaults to `[128, 32]`):
Channel sizes for the convolutional layers in the Sub-sample Convolution Projection.
residual_weight (`float`, defaults to `0.5`):
Scaling applied to hidden_states prior to combining with the residual in the feedforward.
attention_chunk_size (`int`, defaults to `12`):
The sub-sequence size for attention processing.
attention_context_left (`int`, defaults to `13`):
The leftward context size for the attention chunk.
attention_context_right (`int`, defaults to `0`):
The rightward context size for the attention chunk.
attention_logit_cap (`float`, defaults to `50.0`):
Cap applied to attention weights.
attention_invalid_logits_value (`float`, defaults to `-1e9`):
Value to use for invalid logits in attention.
use_clipped_linears (`bool`, defaults to `True`):
If true, apply clipping to the Linear layers, drawing bounds from the model checkpoint.
gradient_clipping (`float`, defaults to `1e10`):
Clipping value used to stabilize extremely large gradient values.
output_proj_dims (`int`, defaults to `1536`):
Dimension of the final linear projection from `hidden_size` to the model's output.
"""
model_type = "gemma4_audio"
hidden_size: int = 1024
num_hidden_layers: int = 12
num_attention_heads: int = 8
hidden_act: str = "silu"
# subsampling parameters
subsampling_conv_channels: list[int] | tuple[int, int] = (128, 32)
# conformer parameters
conv_kernel_size: int = 5
residual_weight: float = 0.5
attention_chunk_size: int = 12
attention_context_left: int = 13
attention_context_right: int = 0
attention_logit_cap: float = 50.0
attention_invalid_logits_value: float = -1.0e9
use_clipped_linears: bool = True
rms_norm_eps: float = 1e-6
gradient_clipping: float = 1e10
output_proj_dims: int = 1536
initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
def __post_init__(self, **kwargs):
# JSON serialization stores sequences as lists, so normalize a tuple default to a list as well
if isinstance(self.subsampling_conv_channels, tuple):
self.subsampling_conv_channels = list(self.subsampling_conv_channels)
super().__post_init__(**kwargs)
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4TextConfig(PreTrainedConfig):
r"""
use_bidirectional_attention (`str`, *optional*):
Controls bidirectional attention behavior. When set to `"vision"`, vision tokens
attend bidirectionally while text tokens use causal attention. When set to `"all"`,
all tokens use bidirectional attention.
vocab_size_per_layer_input (`int`, defaults to 262144):
Vocabulary size for the per-layer input embeddings (PLE). Used by models with
per-layer residual streams where a smaller embedding is added at each decoder layer.
hidden_size_per_layer_input (`int`, defaults to 256):
Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
`[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#per-layer-embeddings-ple) docs
for a description of the full PLE pipeline.
num_global_key_value_heads (`int`, *optional*):
Number of key-value heads for global (full) attention layers. If `None`, defaults
to `num_key_value_heads`.
global_head_dim (`int`, defaults to 512):
Dimension of each attention head in global (full) attention layers.
attention_k_eq_v (`bool`, defaults to `False`):
Whether keys and values share the same projection weights. When `True`, the key
projection output is reused as the value projection.
num_kv_shared_layers (`int`, defaults to 0):
Number of consecutive decoder layers that share the same key-value projections.
A value of 0 means no sharing (each layer has independent KV projections).
enable_moe_block (`bool`, defaults to `False`):
Whether to enable Mixture-of-Experts (MoE) blocks in the decoder layers. When
`True`, eligible layers will use a sparse MoE feed-forward network.
use_double_wide_mlp (`bool`, defaults to `False`):
Whether to use a double-width MLP with fused gate and up projections.
top_k_experts (`int`, *optional*):
Number of experts activated per token in MoE layers. Only used when
`enable_moe_block=True`.
moe_intermediate_size (`int`, *optional*):
Intermediate (hidden) size of each expert's feed-forward network in MoE layers.
Only used when `enable_moe_block=True`.
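Example:
The arithmetic behind two of the defaults above can be sketched in plain Python. This is an illustrative sketch only: it does not import the library, it simply restates the default values declared in this class and the default layer pattern built in `__post_init__`, and the names `ple_weight_shape` and `full_layers` are hypothetical helpers introduced here.
```python
# Illustrative arithmetic only: values copied from the class defaults, no library import.
num_hidden_layers = 30
hidden_size_per_layer_input = 256
vocab_size_per_layer_input = 262_144

# Packed PLE table: one row per token id, with every layer's slice side by side.
ple_weight_shape = (
    vocab_size_per_layer_input,
    num_hidden_layers * hidden_size_per_layer_input,
)

# Default layer pattern: every 6th layer uses full attention (5 sliding : 1 full).
sliding_window_pattern = 6
layer_types = [
    "sliding_attention" if (i + 1) % sliding_window_pattern else "full_attention"
    for i in range(num_hidden_layers)
]
full_layers = [i for i, kind in enumerate(layer_types) if kind == "full_attention"]

print(ple_weight_shape)  # (262144, 7680)
print(full_layers)  # [5, 11, 17, 23, 29]
```
With the default of 30 layers the last layer already lands on `full_attention`, so the forcing branch in `__post_init__` is not triggered.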
"""
model_type = "gemma4_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.experts.gate_up_proj": "packed_colwise",
"layers.*.experts.down_proj": "rowwise",
"layers.*.experts": "moe_tp_experts",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
vocab_size: int = 262_144
hidden_size: int = 2304
intermediate_size: int = 9216
num_hidden_layers: int = 30
num_attention_heads: int = 8
num_key_value_heads: int = 4
head_dim: int = 256
hidden_activation: str = "gelu_pytorch_tanh"
max_position_embeddings: int = 131_072
initializer_range: float = 0.02
rms_norm_eps: float = 1e-6
use_cache: bool = True
pad_token_id: int | None = 0
eos_token_id: int | list[int] | None = 1
bos_token_id: int | None = 2
tie_word_embeddings: bool = True
rope_parameters: dict | None = None
attention_bias: bool = False
attention_dropout: int | float | None = 0.0
sliding_window: int = 512
layer_types: list[str] | None = None
final_logit_softcapping: float | None = None
use_bidirectional_attention: Literal["all", "vision"] | None = None
vocab_size_per_layer_input: int = 262_144
hidden_size_per_layer_input: int = 256
num_global_key_value_heads: int | None = None
global_head_dim: int = 512
attention_k_eq_v: bool = False
num_kv_shared_layers: int = 0
enable_moe_block: bool = False
use_double_wide_mlp: bool = False
num_experts: int | None = None
top_k_experts: int | None = None
moe_intermediate_size: int | None = None
def __post_init__(self, **kwargs):
if self.use_bidirectional_attention == "all":
self.sliding_window = (self.sliding_window // 2) + 1 # FlashAttention treats the window bounds as exclusive
if self.layer_types is None:
sliding_window_pattern = 6 # by default 5:1
self.layer_types = [
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
for i in range(self.num_hidden_layers)
]
if self.layer_types and (last_layer_type := self.layer_types[-1]) != "full_attention":
logger.warning(
f"Last layer must use `full_attention`, but got `{last_layer_type}`. Forcing last layer to `full_attention`."
)
self.layer_types[-1] = "full_attention"
default_rope_params: dict[Literal["full_attention", "sliding_attention"], dict[str, Any]] = {
"sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
"full_attention": {"rope_type": "proportional", "partial_rotary_factor": 0.25, "rope_theta": 1_000_000.0},
}
if self.rope_parameters is None:
self.rope_parameters = default_rope_params
super().__post_init__(**kwargs)
def convert_rope_params_to_dict(self, **kwargs):
# No need to handle BC for new models, because they have no old-format `rope_scaling`
return kwargs
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4VisionConfig(PreTrainedConfig):
r"""
pooling_kernel_size (`int`, *optional*):
Spatial pooling kernel size applied after patchification.
position_embedding_size (`int`, defaults to 10240):
Maximum number of position embeddings for the vision encoder. Controls the size of
the learned 2D position embedding table used by the patch embedder.
use_clipped_linears (`bool`, defaults to `False`):
Whether to use weight-clipped linear layers. When enabled, linear layer weights are
clamped to a fixed range during the forward pass to improve numerical stability.
standardize (`bool`, defaults to `False`):
If true, applies a bias and scale to the soft tokens returned from the pooler.
"""
model_type = "gemma4_vision"
base_model_tp_plan = {
"encoder.layers.*.self_attn.q_proj": "colwise",
"encoder.layers.*.self_attn.k_proj": "colwise",
"encoder.layers.*.self_attn.v_proj": "colwise",
"encoder.layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
"encoder.layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
"encoder.layers.*.self_attn.o_proj": "rowwise",
"encoder.layers.*.mlp.gate_proj": "colwise",
"encoder.layers.*.mlp.up_proj": "colwise",
"encoder.layers.*.mlp.down_proj": "rowwise",
}
default_theta = 100.0
hidden_size: int = 768
intermediate_size: int = 3072
num_hidden_layers: int = 16
num_attention_heads: int = 12
num_key_value_heads: int = 12
head_dim: int = 64
hidden_activation: str = "gelu_pytorch_tanh"
rms_norm_eps: float = 1e-6
max_position_embeddings: int = 131_072
attention_bias: bool | None = False
attention_dropout: float | None = 0.0
rope_parameters: dict | None = None
pooling_kernel_size: int = 3
patch_size: int = 16
position_embedding_size: int = 10 * 1024
use_clipped_linears: bool = False
standardize: bool = False
initializer_range: float = 0.02
def __post_init__(self, **kwargs):
if self.rope_parameters is None:
self.rope_parameters = {"rope_type": "default", "rope_theta": 100.0}
super().__post_init__(**kwargs)
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
@strict
class Gemma4Config(PreTrainedConfig):
r"""
boi_token_id (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_id (`int`, *optional*, defaults to 258882):
The end-of-image token index to wrap the image prompt.
boa_token_id (`int`, *optional*, defaults to 256000):
The begin-of-audio token index to wrap the audio prompt.
eoa_token_index (`int`, *optional*, defaults to 258883):
The end-of-audio token index to wrap the audio prompt.
Example:
```python
>>> from transformers import (
>>> Gemma4AudioConfig,
>>> Gemma4Config,
>>> Gemma4ForConditionalGeneration,
>>> Gemma4TextConfig,
>>> Gemma4VisionConfig,
>>> )
>>> # Initializing a Gemma 4 Audio config.
>>> audio_config = Gemma4AudioConfig()
>>> # Initializing a Gemma 4 Text config.
>>> text_config = Gemma4TextConfig()
>>> # Initializing a Gemma 4 vision config.
>>> vision_config = Gemma4VisionConfig()
>>> # Initializing a Gemma 4 config similar to google/gemma-4-e2b-it
>>> configuration = Gemma4Config(text_config, vision_config, audio_config)
>>> # Initializing a model from the google/gemma-4-e2b-it configuration
>>> model = Gemma4ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma4"
sub_configs = {
"text_config": Gemma4TextConfig,
"vision_config": Gemma4VisionConfig,
"audio_config": Gemma4AudioConfig,
}
text_config: Gemma4TextConfig | dict[str, Any] | None = None
vision_config: Gemma4VisionConfig | dict[str, Any] | None = None
audio_config: Gemma4AudioConfig | dict[str, Any] | None = None
boi_token_id: int | None = 255_999
eoi_token_id: int | None = 258_882
image_token_id: int | None = 258_880
video_token_id: int | None = 258_884
boa_token_id: int | None = 256_000
eoa_token_index: int | None = 258_883
audio_token_id: int | None = 258_881
initializer_range: float | None = 0.02
tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if self.text_config is None:
self.text_config = Gemma4TextConfig()
logger.info("text_config is None. Using default Gemma4TextConfig.")
elif isinstance(self.text_config, dict):
self.text_config = Gemma4TextConfig(**self.text_config)
if self.vision_config is None:
logger.info("vision_config is None. Gemma4Model.vision_tower will not be initialized.")
if isinstance(self.vision_config, dict):
self.vision_config = Gemma4VisionConfig(**self.vision_config)
if self.audio_config is None:
logger.info("audio_config is None. Gemma4Model.audio_tower will not be initialized.")
if isinstance(self.audio_config, dict):
self.audio_config = Gemma4AudioConfig(**self.audio_config)
super().__post_init__(**kwargs)
__all__ = ["Gemma4AudioConfig", "Gemma4Config", "Gemma4TextConfig", "Gemma4VisionConfig"]
@@ -0,0 +1,298 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from collections.abc import Sequence
import numpy as np
from ...audio_utils import mel_filter_bank, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
logger = logging.get_logger(__name__)
def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
"""A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
if array.ndim != 2:
raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
if dimension != -1 and dimension != array.ndim - 1:
raise ValueError("This unfold implementation only supports unfolding the last dimension.")
batch_size, original_length = array.shape
num_frames = (original_length - size) // step + 1
if num_frames <= 0:
return np.zeros((batch_size, 0, size), dtype=array.dtype)
output_shape = (batch_size, num_frames, size)
output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
class Gemma4AudioFeatureExtractor(SequenceFeatureExtractor):
"""An audio feature extractor for Universal Speech Models (https://huggingface.co/papers/2303.01037).
Args:
feature_size (`int`, *optional*, defaults to 128):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention mask for the generated MEL spectrograms.
frame_length_ms (`float`, *optional*, defaults to 20.0):
The length of a frame in milliseconds.
hop_length_ms (`float`, *optional*, defaults to 10.0):
The hop (stride) in milliseconds between successive STFT frames used to obtain the Mel Frequency coefficients.
min_frequency (`float`, *optional*, defaults to 0.0):
The minimum frequency (in Hz) for the Mel filterbank.
max_frequency (`float`, *optional*, defaults to 8000.0):
The maximum frequency (in Hz) for the Mel filterbank.
preemphasis (`float`, *optional*, defaults to 0.0):
The preemphasis coefficient.
preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
Whether to use HTK-style preemphasis.
fft_overdrive (`bool`, *optional*, defaults to `False`):
Whether to double the FFT length, which is otherwise the smallest power of two that fits one frame.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 0.0001 to add dithering with a normal distribution centered
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
The value 0.0 means no dithering.
Dithering has a similar effect to `spectrogram(mel_floor=...)`. It reduces
the high log_mel_fbank values for signals with hard-zero sections,
when VAD cutoff is present in the signal.
input_scale_factor (`float`, *optional*, defaults to 1.0):
Scaling factor applied to the input waveform.
mel_floor (`float`, *optional*, defaults to 0.001):
Minimum value for Mel spectrograms to avoid log(0).
per_bin_mean (`Optional[Sequence[float]]`, *optional*):
Mean values for per-bin normalization.
per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
Standard deviation values for per-bin normalization.
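Example:
The framing parameters derived in `__init__` can be checked by hand. The block below is a minimal sketch that assumes the default arguments and restates the arithmetic from this file in plain Python, without importing the library. The one-second clip length and the variable names `num_samples`, `padded_length`, and `num_frames` are assumptions made for illustration.
```python
import math

# Defaults copied from the constructor signature.
sampling_rate = 16_000
frame_length_ms = 20.0
hop_length_ms = 10.0

# Same rounding as in __init__.
frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))

# FFT length is the smallest power of two that fits one frame (fft_overdrive=False).
fft_length = 2 ** math.ceil(math.log2(frame_length))
num_frequency_bins = fft_length // 2 + 1

# Hypothetical one-second clip: semicausal padding prepends frame_length // 2 zeros,
# and the unfold step takes windows of frame_length + 1 samples every hop_length.
num_samples = sampling_rate
padded_length = num_samples + frame_length // 2
num_frames = (padded_length - (frame_length + 1)) // hop_length + 1

print(frame_length, hop_length, fft_length, num_frequency_bins)  # 320 160 512 257
print(num_frames)  # 99
```
The 512-point FFT with 257 bins matches the figures quoted in the mel filter bank comment in `__init__`.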
"""
model_input_names = ["input_features", "input_features_mask"]
def __init__(
self,
feature_size: int = 128,
sampling_rate: int = 16_000,
padding_value: float = 0.0,
return_attention_mask: bool = True,
frame_length_ms: float = 20.0,
hop_length_ms: float = 10.0,
min_frequency: float = 0.0,
max_frequency: float = 8000.0,
preemphasis: float = 0.0,
preemphasis_htk_flavor: bool = True,
fft_overdrive: bool = False,
dither: float = 0.0,
input_scale_factor: float = 1.0,
mel_floor: float = 1e-3,
per_bin_mean: Sequence[float] | None = None,
per_bin_stddev: Sequence[float] | None = None,
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.min_frequency = min_frequency
self.max_frequency = max_frequency
self.preemphasis = preemphasis
self.preemphasis_htk_flavor = preemphasis_htk_flavor
self.fft_overdrive = fft_overdrive
self.dither = dither
self.input_scale_factor = input_scale_factor
self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
self.mel_floor = np.array(mel_floor, dtype=np.float64)
fft_length = 2 ** math.ceil(math.log2(self.frame_length))
if self.fft_overdrive:
fft_length *= 2
self.fft_length = fft_length
# Use periodic Hann window, matching sl.STFT default (signal.hann_window)
# For even frame_length: window[n] = 0.5 - 0.5 * cos(2*pi*n / frame_length)
self.window = window_function(self.frame_length).astype(np.float32)
# Use HuggingFace's mel_filter_bank for compatibility.
# Suppress the expected warning about all-zero upper mel filters;
# with fft_length=512 (257 bins) and 128 mel filters the uppermost
# triangular filter falls between frequency bins, which is harmless.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.mel_filters = mel_filter_bank(
num_frequency_bins=self.fft_length // 2 + 1,
num_mel_filters=feature_size,
min_frequency=min_frequency,
max_frequency=max_frequency,
sampling_rate=self.sampling_rate,
norm=None,
mel_scale="htk",
)
if per_bin_mean is not None:
self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
else:
self.per_bin_mean = None
if per_bin_stddev is not None:
self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
else:
self.per_bin_stddev = None
def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Compute the log-mel spectrogram of a single waveform and a per-frame validity mask."""
if waveform.ndim == 1: # If single waveform, add batch dimension
waveform = np.expand_dims(waveform, axis=0)
if self.dither > 0.0:
waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
if self.input_scale_factor != 1.0:
waveform = waveform * self.input_scale_factor
# Semicausal time padding: prepend frame_length // 2 zeros so that the
# first STFT frame is centered at t=0, matching sl.STFT(time_padding='semicausal').
pad_left = self.frame_length // 2
waveform = np.pad(waveform, ((0, 0), (pad_left, 0)), mode="constant")
attention_mask = np.pad(attention_mask, (pad_left, 0), mode="constant", constant_values=0)
frame_size_for_unfold = self.frame_length + 1
# NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
if self.preemphasis > 0.0:
if self.preemphasis_htk_flavor:
first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
else:
frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
else:
frames = frames_to_process[..., :-1]
# Apply window, then RFFT. np.fft.rfft with n=fft_length implicitly
# right-pads frames to fft_length.
frames = frames * self.window # Broadcasting window
stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
magnitude_spec = np.abs(stft)
mel_spec = np.matmul(magnitude_spec, self.mel_filters)
log_mel_spec = np.log(mel_spec + self.mel_floor)
if self.per_bin_mean is not None:
log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
if self.per_bin_stddev is not None:
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
mel_spectrogram = log_mel_spec.squeeze(0)
num_mel_frames = mel_spectrogram.shape[0]
# Build a frame-aware mask: a mel frame is valid only when every sample
# in its analysis window [i*hop, i*hop + frame_size - 1] is real audio.
# We check this by looking at the last sample of each frame's window.
frame_end_indices = np.arange(num_mel_frames) * self.hop_length + frame_size_for_unfold - 1
mask = attention_mask[frame_end_indices].astype(bool)
return mel_spectrogram, mask
def __call__(
self,
raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
padding: bool | str | PaddingStrategy = "longest",
max_length: int | None = 480_000,
truncation: bool = True,
pad_to_multiple_of: int | None = 128,
return_tensors: str | TensorType | None = None,
return_attention_mask: bool | None = True,
**kwargs,
) -> BatchFeature:
"""Creates a batch of MEL spectrograms from the provided raw speech.
This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
`transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
carefully when selecting an audio feature extractor, especially with pre-trained models.
Args:
raw_speech:
The audio for which MEL spectrograms are created.
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
The padding strategy to use for batches of audio with different lengths.
max_length (`int`, *optional*, defaults to 480000):
If provided, defines the maximum length of the audio to allow. Audio longer than this will be
truncated if `truncation=True`.
truncation (`bool`, *optional*, defaults to `True`):
Whether or not to truncate audio above `max_length`.
pad_to_multiple_of (`int`, *optional*, defaults to 128):
When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
The type of tensors to return (e.g., NumPy, or Torch).
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention mask for the generated MEL spectrograms.
"""
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
is_batched = is_batched_numpy or is_batched_sequence
if is_batched:
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
if not is_batched: # always return a batch
raw_speech = [np.asarray([raw_speech])]
batched_speech = self.pad(
BatchFeature({"input_features": raw_speech}),
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
prepared_speech = []
prepared_speech_mask = []
for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
speech, mask = self._extract_spectrogram(speech.T, mask)
prepared_speech.append(speech.astype(np.float32))
prepared_speech_mask.append(mask)
prepared_speech = [speech * mask[..., None] for speech, mask in zip(prepared_speech, prepared_speech_mask)]
return BatchFeature(
{"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
tensor_type=return_tensors,
)
__all__ = ["Gemma4AudioFeatureExtractor"]
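The framing arithmetic in `__init__` and `_extract_spectrogram` above can be traced with the default parameters. The following is a minimal sketch, not the library's implementation: `frame_waveform` is a hypothetical stand-in for the `_unfold` helper (whose body is not shown in this file) and assumes it behaves like a strided sliding window.

```python
import math

import numpy as np

# Defaults copied from Gemma4AudioFeatureExtractor.__init__ above.
sampling_rate = 16_000
frame_length_ms = 20.0
hop_length_ms = 10.0

frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))  # samples per frame
hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))      # samples per hop
fft_length = 2 ** math.ceil(math.log2(frame_length))                 # next power of two
num_frequency_bins = fft_length // 2 + 1                             # rfft output size


def count_frames(num_samples: int) -> int:
    """Frames for one waveform after the semicausal left pad (assumed unfold semantics)."""
    padded = num_samples + frame_length // 2  # zeros prepended on the left
    window = frame_length + 1                 # one extra sample, consumed by preemphasis
    if padded < window:
        return 0
    return (padded - window) // hop_length + 1


def frame_waveform(waveform: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for `_unfold`: returns [num_frames, frame_length + 1]."""
    padded = np.pad(waveform, (frame_length // 2, 0), mode="constant")
    windows = np.lib.stride_tricks.sliding_window_view(padded, frame_length + 1)
    return windows[::hop_length]


one_second = np.zeros(sampling_rate, dtype=np.float32)
frames = frame_waveform(one_second)

# Frame-level mask: suppose only the first 8000 samples are real audio.
sample_mask = np.zeros(sampling_rate, dtype=np.int64)
sample_mask[:8000] = 1
padded_mask = np.pad(sample_mask, (frame_length // 2, 0), mode="constant")
# Index of the last sample in each frame's window, as in `frame_end_indices`.
frame_end = np.arange(frames.shape[0]) * hop_length + frame_length
frame_mask = padded_mask[frame_end].astype(bool)
```

With the defaults this gives 320-sample frames, a 160-sample hop and a 512-point FFT (257 bins), which matches the `fft_length=512 (257 bins)` remark in the mel filter bank comment.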
@@ -0,0 +1,220 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.transforms.v2 import functional as F
from ...image_processing_backends import TorchvisionBackend
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, PILImageResampling
from ...processing_utils import ImagesKwargs, Unpack
from ...utils import TensorType, auto_docstring, logging
from .image_processing_pil_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size
logger = logging.get_logger(__name__)
# Copied from transformers.models.siglip2.image_processing_siglip2.convert_image_to_patches
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
"""
Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
"""
num_channels, image_height, image_width = image.shape
num_patches_height = image_height // patch_size
num_patches_width = image_width // patch_size
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
patched_image = patched_image.permute(1, 3, 2, 4, 0)
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
return patched_image
# Adapted from Siglip2 (mask -> position ids)
def pad_along_first_dim(
image: "torch.Tensor", positions: "torch.Tensor", target_length: int
) -> tuple["torch.Tensor", "torch.Tensor"]:
"""
Pad the tensor along the first dimension.
"""
current_length = image.shape[0]
padding_length = target_length - current_length
if padding_length > 0:
padding = [0, 0] * (image.ndim - 1) + [0, padding_length]
pos_padding = (0, 0, 0, padding_length)
image = torch.nn.functional.pad(image, padding, mode="constant", value=0)
positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1)
return image, positions
class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False):
"""
patch_size (`int`, *optional*):
Size of each image patch in pixels.
max_soft_tokens (`int`, *optional*):
Maximum number of soft (vision) tokens per image.
Must be one of {70, 140, 280, 560, 1120}.
pooling_kernel_size (`int`, *optional*):
Spatial pooling kernel size applied after patchification.
"""
patch_size: int
max_soft_tokens: int
pooling_kernel_size: int
@auto_docstring(custom_intro="Constructs a Gemma4 image processor.")
class Gemma4ImageProcessor(TorchvisionBackend):
resample = PILImageResampling.BICUBIC
image_mean = [0.0, 0.0, 0.0]
image_std = [1.0, 1.0, 1.0]
size = None
default_to_square = True
do_convert_rgb = True
do_resize = True
do_rescale = True
do_normalize = False
patch_size = 16
max_soft_tokens = 280
pooling_kernel_size = 3
valid_kwargs = Gemma4ImageProcessorKwargs
model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"]
def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]):
super().__init__(**kwargs)
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
def _validate_preprocess_kwargs(self, **kwargs):
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
# parameter. Temporarily disable do_resize so the base validation
# doesn't require `size` to be set.
kwargs["do_resize"] = False
super()._validate_preprocess_kwargs(**kwargs)
def aspect_ratio_preserving_resize(
self,
image: torch.Tensor,
patch_size: int,
max_patches: int,
pooling_kernel_size: int,
resample: F.InterpolationMode,
) -> torch.Tensor:
height, width = image.shape[-2], image.shape[-1]
target_height, target_width = get_aspect_ratio_preserving_size(
height=height,
width=width,
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
)
if target_height == height and target_width == width:
return image
return F.resize(
image,
size=[target_height, target_width],
interpolation=resample,
antialias=True,
)
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[Gemma4ImageProcessorKwargs],
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
resample: "PILImageResampling | F.InterpolationMode | int | None",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
return_tensors: str | TensorType | None,
patch_size: int | None = None,
max_soft_tokens: int | None = None,
pooling_kernel_size: int | None = None,
**kwargs,
) -> BatchFeature:
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
# Compute max_patches from max_soft_tokens and pooling_kernel_size
max_patches = max_soft_tokens * pooling_kernel_size**2
# Process each image individually: resize, rescale/normalize, patchify, pad.
# Images have different aspect ratios and thus different resized dimensions,
# so patchification and padding must happen per-image before stacking.
pixel_values = []
position_ids = []
num_soft_tokens_per_image = []
for image in images:
# Step 1: Aspect-ratio-preserving resize
if do_resize:
image = self.aspect_ratio_preserving_resize(
image=image,
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
resample=resample,
)
# Step 2: Rescale pixel values (typically to [0, 1]) and optionally identity normalize
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
# Step 3: Patchify the image
# (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
patch_height = image.shape[-2] // patch_size
patch_width = image.shape[-1] // patch_size
patches = convert_image_to_patches(image, patch_size)
num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2)
# Step 4: Compute position IDs
device = image.device
patch_grid = torch.meshgrid(
torch.arange(patch_width, device=device),
torch.arange(patch_height, device=device),
indexing="xy",
)
stacked_grid = torch.stack(patch_grid, dim=-1)
real_positions = stacked_grid.reshape(patches.shape[0], 2)
# Step 5: Pad patches and positions to `max_patches`
patches, positions = pad_along_first_dim(patches, real_positions, max_patches)
pixel_values.append(patches)
position_ids.append(positions)
# Stack into batch tensors
pixel_values = torch.stack(pixel_values, dim=0) # (batch, max_patches, patch_pixels)
position_ids = torch.stack(position_ids, dim=0) # (batch, max_patches, 2)
data = {
"pixel_values": pixel_values,
"image_position_ids": position_ids,
"num_soft_tokens_per_image": num_soft_tokens_per_image,
}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma4ImageProcessor"]
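The patch budget used in `_preprocess` above (`max_patches = max_soft_tokens * pooling_kernel_size**2`) and the resize target it implies can be worked through for one image. This is a simplified sketch of the arithmetic in `get_aspect_ratio_preserving_size` (imported from the PIL module). `target_size` is a hypothetical name, and it omits the upstream edge cases where a dimension rounds to zero.

```python
import math

# Class defaults from Gemma4ImageProcessor above.
patch_size = 16
max_soft_tokens = 280
pooling_kernel_size = 3

max_patches = max_soft_tokens * pooling_kernel_size**2  # patch budget before pooling
side_mult = pooling_kernel_size * patch_size            # both sides must be a multiple of this


def target_size(height: int, width: int) -> tuple[int, int]:
    """Simplified sketch: scale to the pixel budget, then round each side down."""
    target_px = max_patches * patch_size**2
    factor = math.sqrt(target_px / (height * width))
    target_height = math.floor(factor * height / side_mult) * side_mult
    target_width = math.floor(factor * width / side_mult) * side_mult
    return target_height, target_width


height, width = target_size(480, 640)
num_patches = (height // patch_size) * (width // patch_size)
num_soft_tokens = num_patches // pooling_kernel_size**2
```

For a 480x640 input this yields a 672x912 target, 2394 patches and 266 soft tokens, which stays under the 280-token budget because both sides are rounded down.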
@@ -0,0 +1,278 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from ...image_processing_backends import PilBackend
from ...image_processing_utils import BatchFeature
from ...image_transforms import resize
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, Unpack
from ...utils import TensorType, auto_docstring, is_vision_available, logging
if is_vision_available():
from ...image_utils import PILImageResampling
logger = logging.get_logger(__name__)
_SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120)
def get_aspect_ratio_preserving_size(
height: int,
width: int,
patch_size: int,
max_patches: int,
pooling_kernel_size: int,
) -> tuple[int, int]:
"""
Image is resized to preserve aspect ratio so it fits within the patch budget.
Target dimensions are the largest that:
1) Produce at most `max_patches` patches when patchified with `patch_size`
2) Have height and width divisible by `pooling_kernel_size * patch_size`
"""
total_px = height * width
target_px = max_patches * (patch_size**2)
factor = math.sqrt(target_px / total_px)
ideal_height = factor * height
ideal_width = factor * width
side_mult = pooling_kernel_size * patch_size
# Round down to nearest multiple of side_mult
target_height = int(math.floor(ideal_height / side_mult)) * side_mult
target_width = int(math.floor(ideal_width / side_mult)) * side_mult
# Handle edge cases where one or both dimensions round to 0
if target_height == 0 and target_width == 0:
raise ValueError(
"Attempting to resize to a 0 x 0 image. Resized height should be divisible by "
f"`pooling_kernel_size * patch_size`={pooling_kernel_size * patch_size}."
)
max_side_length = (max_patches // pooling_kernel_size**2) * side_mult
if target_height == 0:
target_height = side_mult
target_width = min(
int(math.floor(width / height)) * side_mult,
max_side_length,
)
elif target_width == 0:
target_width = side_mult
target_height = min(
int(math.floor(height / width)) * side_mult,
max_side_length,
)
if target_height * target_width > target_px:
raise ValueError(
f"Resizing [{height}x{width}] to [{target_height}x{target_width}] "
f"but this exceeds {max_patches} patches with patch_size {patch_size}"
)
return target_height, target_width
# Copied from transformers.models.siglip2.image_processing_pil_siglip2.convert_image_to_patches
def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
"""
Convert 3D array image of shape (num_channels, image_height, image_width) into 2D array of patches of shape
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
"""
num_channels, image_height, image_width = image.shape
num_patches_height = image_height // patch_size
num_patches_width = image_width // patch_size
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
patched_image = patched_image.transpose(1, 3, 2, 4, 0)
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
return patched_image
# Adapted from Siglip2 (mask -> position ids)
def pad_along_first_dim(image: np.ndarray, positions: np.ndarray, target_length: int) -> tuple[np.ndarray, np.ndarray]:
"""
Pad the image along the first dimension.
"""
current_length = image.shape[0]
padding_length = target_length - current_length
if padding_length > 0:
paddings = [(0, padding_length)] + [(0, 0)] * (image.ndim - 1)
pos_paddings = [(0, padding_length), (0, 0)]
image = np.pad(image, paddings, mode="constant", constant_values=0)
positions = np.pad(positions, pos_paddings, mode="constant", constant_values=-1)
return image, positions
class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False):
"""
patch_size (`int`, *optional*):
Size of each image patch in pixels.
max_soft_tokens (`int`, *optional*):
Maximum number of soft (vision) tokens per image.
Must be one of {70, 140, 280, 560, 1120}.
pooling_kernel_size (`int`, *optional*):
Spatial pooling kernel size applied after patchification.
"""
patch_size: int
max_soft_tokens: int
pooling_kernel_size: int
@auto_docstring(custom_intro="Constructs a Gemma4 image processor.")
class Gemma4ImageProcessorPil(PilBackend):
valid_kwargs = Gemma4ImageProcessorKwargs
model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"]
do_resize = True
resample = PILImageResampling.BICUBIC
do_rescale = True
rescale_factor = 1 / 255
do_normalize = False
image_mean = [0.0, 0.0, 0.0]
image_std = [1.0, 1.0, 1.0]
do_convert_rgb = True
patch_size = 16
max_soft_tokens = 280
pooling_kernel_size = 3
def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]) -> None:
super().__init__(**kwargs)
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
def _validate_preprocess_kwargs(self, **kwargs):
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
# parameter. Temporarily disable do_resize so the base validation
# doesn't require `size` to be set.
kwargs["do_resize"] = False
super()._validate_preprocess_kwargs(**kwargs)
@auto_docstring
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[Gemma4ImageProcessorKwargs],
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def aspect_ratio_preserving_resize(
self,
image: np.ndarray,
patch_size: int,
max_patches: int,
pooling_kernel_size: int,
resample: PILImageResampling,
) -> np.ndarray:
height, width = image.shape[-2], image.shape[-1]
target_height, target_width = get_aspect_ratio_preserving_size(
height=height,
width=width,
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
)
if target_height == height and target_width == width:
return image
return resize(
image,
size=(target_height, target_width),
resample=resample,
)
def _preprocess(
self,
images: list[np.ndarray],
do_resize: bool,
resample: "PILImageResampling | int | None",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
return_tensors: str | TensorType | None,
max_soft_tokens: int | None = None,
patch_size: int | None = None,
pooling_kernel_size: int | None = None,
**kwargs,
) -> BatchFeature:
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
# Compute max_patches from max_soft_tokens and pooling_kernel_size
max_patches = max_soft_tokens * pooling_kernel_size**2
# Process each image individually: resize, rescale/normalize, patchify, pad.
# Images have different aspect ratios and thus different resized dimensions,
# so patchification and padding must happen per-image before stacking.
pixel_values = []
position_ids = []
num_soft_tokens_per_image = []
for image in images:
# Step 1: Aspect-ratio-preserving resize
if do_resize:
image = self.aspect_ratio_preserving_resize(
image=image,
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
resample=resample,
)
# Step 2: Rescale pixel values from [0, 255] to [0, 1]
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
# Step 3: Identity normalization because Gemma4 was trained with pixels in [0, 1]
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
# Step 4: Patchify the image
# image is a (C, H, W) NumPy array
# (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
patches = convert_image_to_patches(image, patch_size)
num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2)
# Step 5: Compute position IDs
patch_height = image.shape[-2] // patch_size
patch_width = image.shape[-1] // patch_size
grid_x, grid_y = np.meshgrid(np.arange(patch_width), np.arange(patch_height), indexing="xy")
real_positions = np.stack([grid_x, grid_y], axis=-1).reshape(patches.shape[0], 2)
patches, positions = pad_along_first_dim(patches, real_positions, max_patches)
pixel_values.append(patches)
position_ids.append(positions)
# Stack into batch arrays and convert to tensors
pixel_values = np.stack(pixel_values, axis=0) # (batch, max_patches, patch_pixels)
position_ids = np.stack(position_ids, axis=0) # (batch, max_patches, 2)
data = {
"pixel_values": pixel_values,
"image_position_ids": position_ids,
"num_soft_tokens_per_image": num_soft_tokens_per_image,
}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma4ImageProcessorPil"]
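The patchify, position-ID and padding steps in `_preprocess` above can be followed on a tiny array. The sketch below reuses the same NumPy calls on a toy 3x4x6 image with `patch_size = 2` and a budget of 8 patches; these toy values are assumptions chosen for readability, not the model defaults.

```python
import numpy as np

patch_size = 2
max_patches = 8  # toy budget; the processor derives this from max_soft_tokens
image = np.arange(3 * 4 * 6, dtype=np.float32).reshape(3, 4, 6)  # (C, H, W)

num_channels, height, width = image.shape
patch_height, patch_width = height // patch_size, width // patch_size

# Patchify: (C, H, W) -> (num_patches, patch_size * patch_size * C), channels last.
patches = image.reshape(num_channels, patch_height, patch_size, patch_width, patch_size)
patches = patches.transpose(1, 3, 2, 4, 0).reshape(patch_height * patch_width, -1)

# (x, y) grid position of every patch, in the same row-major order as `patches`.
grid_x, grid_y = np.meshgrid(np.arange(patch_width), np.arange(patch_height), indexing="xy")
positions = np.stack([grid_x, grid_y], axis=-1).reshape(patches.shape[0], 2)

# Pad both arrays along the first dimension; padded positions are marked with -1.
padding_length = max_patches - patches.shape[0]
patches = np.pad(patches, [(0, padding_length), (0, 0)], mode="constant", constant_values=0)
positions = np.pad(positions, [(0, padding_length), (0, 0)], mode="constant", constant_values=-1)
```

The `-1` entries in `positions` let downstream code tell padding from real patches without a separate mask.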
@@ -0,0 +1,723 @@
# === HEADER (license + imports) ===
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma4.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations import use_experts_implementation, use_kernelized_func
from ...masking_utils import (
create_bidirectional_mask,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_accelerate_available,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..auto.modeling_auto import AutoModel
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
if is_accelerate_available():
from accelerate.hooks import add_hook_to_module
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
@dataclass
@auto_docstring(
custom_intro="""
Base class for Gemma4 outputs, with hidden states and attentions.
"""
)
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
r"""
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
...
@dataclass
@auto_docstring(
custom_intro="""
Base class for Gemma4 causal language model (or autoregressive) outputs.
"""
)
class Gemma4CausalLMOutputWithPast(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
...
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
r"""
attention_mask (`torch.BoolTensor`, *optional*):
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
"""
attention_mask: torch.BoolTensor | None = None
class Gemma4ClippableLinear(nn.Module):
def __init__(
self,
...
class Gemma4RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
super().__init__()
self.eps = eps
self.with_scale = with_scale
if self.with_scale:
self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
def _norm(self, hidden_states: torch.Tensor):
mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
# Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to address compiler differences between Torch and JAX
return hidden_states * torch.pow(mean_squared, -0.5)
...
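The `_norm` step shown above can be checked numerically. This is a NumPy sketch of that one method only, under the assumption that it is applied over the last axis as written. The rest of `Gemma4RMSNorm` is elided in this outline and is not reproduced here, and `rms_norm` is a hypothetical name.

```python
import numpy as np


def rms_norm(hidden_states: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """NumPy analogue of `_norm`: divide by the root mean square over the last axis."""
    mean_squared = np.mean(hidden_states**2, axis=-1, keepdims=True) + eps
    # Same form as the source: a power of -0.5 rather than an explicit sqrt/rsqrt.
    return hidden_states * np.power(mean_squared, -0.5)


x = np.array([[3.0, 4.0], [0.5, -0.5]])
y = rms_norm(x)
rms_after = np.sqrt(np.mean(y**2, axis=-1))  # close to 1 for every row
```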
class Gemma4AudioRelPositionalEncoding(nn.Module):
"""Sinusoidal relative positional encoding for the audio encoder.
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
"""
inv_timescales: torch.Tensor
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.context_size = (
config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
...
class Gemma4AudioAttention(nn.Module):
"""Chunked local attention with relative position bias"""
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_logits_soft_cap = config.attention_logit_cap
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.q_scale = (self.head_dim**-0.5) / math.log(2)
self.k_scale = math.log(1 + math.e) / math.log(2)
...
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
def __init__(self, in_channels, out_channels, norm_eps):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=1,
bias=False,
)
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
self.act = nn.ReLU()
...
class Gemma4AudioSubSampleConvProjection(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
in_channels=1,
out_channels=config.subsampling_conv_channels[0],
norm_eps=config.rms_norm_eps,
)
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
in_channels=config.subsampling_conv_channels[0],
out_channels=config.subsampling_conv_channels[1],
norm_eps=config.rms_norm_eps,
)
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
...
class Gemma4AudioFeedForward(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.config = config
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
self.act_fn = ACT2FN[config.hidden_act]
self.gradient_clipping = config.gradient_clipping
self.post_layer_scale = config.residual_weight
...
class Gemma4AudioCausalConv1d(nn.Conv1d):
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# kernel_size: int,
# # cache_key: str,
# stride: int = 1,
# dilation: int = 1,
# bias: bool = True,
# ):
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
# self.cache_key = cache_key
...
class Gemma4AudioLightConv1d(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.config = config
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=config.conv_kernel_size,
groups=config.hidden_size,
bias=False,
)
...
class Gemma4AudioLayer(nn.Module):
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
super().__init__()
self.config = config
self.feed_forward1 = Gemma4AudioFeedForward(config)
self.feed_forward2 = Gemma4AudioFeedForward(config)
self.self_attn = Gemma4AudioAttention(config, layer_idx)
self.lconv1d = Gemma4AudioLightConv1d(config)
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
self.norm_out = Gemma4RMSNorm(config.hidden_size)
...
class Gemma4VisionPatchEmbedder(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
self.position_embedding_size = config.position_embedding_size
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
"""Prepare patch positions map for matmul with position embedding table."""
# Expand and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
...
class Gemma4VisionPooler(nn.Module):
"""Scaling and optional spatial pooling for vision encodings"""
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.root_hidden_size = self.hidden_size**0.5
def _avg_pool_by_positions(
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
2D spatial pooling according to patch positions.
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
...
class Gemma4VisionMLP(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
...
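# ---- Illustrative sketch (not upstream code) ----
# The forward pass above is a gated MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)).
# The sketch below reproduces that data flow under stated assumptions: plain `nn.Linear`
# stands in for `Gemma4ClippableLinear`, the tanh-approximated GELU is an assumed choice
# for `config.hidden_activation`, and `_SketchGatedMLP` plus the sizes are hypothetical.
import torch
from torch import nn


class _SketchGatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.GELU(approximate="tanh")  # assumption, not read from a real config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The activated gate branch multiplies the linear "up" branch elementwise.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


_sketch_mlp = _SketchGatedMLP(hidden_size=8, intermediate_size=32)
_sketch_mlp_out = _sketch_mlp(torch.randn(2, 5, 8))  # output keeps the input shape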
class Gemma4VisionRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: Gemma4VisionConfig, device=None):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn: Callable = self.compute_default_rope_parameters
if self.rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
...
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
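# ---- Illustrative sketch (not upstream code) ----
# A small check of what `rotate_half` does, assuming only PyTorch is available.
# `_sketch_rotate_half` is a hypothetical local copy so the snippet runs on its own:
# the last dimension is split into halves (x1, x2) and returned as (-x2, x1),
# so applying the function twice negates the input.
import torch


def _sketch_rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


_sketch_x = torch.tensor([1.0, 2.0, 3.0, 4.0])
_sketch_once = _sketch_rotate_half(_sketch_x)  # tensor([-3., -4.,  1.,  2.])
_sketch_twice = _sketch_rotate_half(_sketch_once)  # equals -_sketch_x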
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
x (`torch.Tensor`): The tensor to embed.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
...
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
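# ---- Illustrative sketch (not upstream code) ----
# The docstring above states that `repeat_kv` is equivalent to
# `torch.repeat_interleave(x, dim=1, repeats=n_rep)`. The snippet below checks that claim
# on a small tensor. `_sketch_repeat_kv` is a hypothetical local copy and the tensor
# sizes are arbitrary, chosen only for illustration.
import torch


def _sketch_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the KV-head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


_sketch_kv = torch.arange(2 * 2 * 3 * 4, dtype=torch.float32).reshape(2, 2, 3, 4)
_sketch_repeated = _sketch_repeat_kv(_sketch_kv, n_rep=3)
_sketch_reference = torch.repeat_interleave(_sketch_kv, repeats=3, dim=1)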
def eager_attention_forward(
module: nn.Module,
...
def apply_multidimensional_rope(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: torch.Tensor,
unsqueeze_dim: int = 2,
) -> torch.Tensor:
"""Applies multidimensional RoPE to inputs.
Args:
x (`torch.Tensor`): The tensor to embed.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
...
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
super().__init__()
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = 1.0
self.attention_dropout = self.config.attention_dropout
self.is_causal = False
...
class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
self.mlp = Gemma4VisionMLP(config)
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
def forward(
...
class Gemma4VisionEncoder(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.num_layers = config.num_hidden_layers
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
self.layers = nn.ModuleList(
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
)
def forward(
self,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor,
...
class Gemma4TextMLP(nn.Module):
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
super().__init__()
first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
...
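# ---- Illustrative sketch (not upstream code) ----
# The constructor above doubles `intermediate_size` only for KV-shared layers, and only
# when `use_double_wide_mlp` is set. The helper below isolates that rule as a pure function.
# `_sketch_mlp_intermediate_size` and the example numbers are hypothetical; the chained
# comparison is copied from the constructor.
def _sketch_mlp_intermediate_size(
    layer_idx: int,
    num_hidden_layers: int,
    num_kv_shared_layers: int,
    intermediate_size: int,
    use_double_wide_mlp: bool,
) -> int:
    first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers
    # Chained comparison: the layer must sit at or after the first shared index,
    # and that index must be strictly positive.
    is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
    return intermediate_size * (2 if (use_double_wide_mlp and is_kv_shared_layer) else 1)


# Six layers, the last two KV-shared: only layers 4 and 5 get the double-wide MLP.
_sketch_sizes = [_sketch_mlp_intermediate_size(i, 6, 2, 100, True) for i in range(6)]
# If every layer is KV-shared, the first shared index is 0 and the `> 0` guard disables doubling.
_sketch_all_shared = [_sketch_mlp_intermediate_size(i, 6, 6, 100, True) for i in range(6)]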
class Gemma4TextRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.layer_types = set(config.layer_types)
self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
self.rope_type: dict[str, str] = {}
for layer_type in self.layer_types:
...
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
super().__init__()
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
self.config = config
self.layer_idx = layer_idx
self.is_sliding = self.layer_type == "sliding_attention"
self.sliding_window = config.sliding_window if self.is_sliding else None
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
...
@use_experts_implementation
class Gemma4TextExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config: Gemma4TextConfig):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_activation]
def forward(
...
class Gemma4TextRouter(nn.Module):
def __init__(self, config: Gemma4TextConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.scalar_root_size = self.hidden_size**-0.5
self.eps = config.rms_norm_eps
self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.scale = nn.Parameter(torch.ones(self.hidden_size))
self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
...
class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
self.mlp = Gemma4TextMLP(config, layer_idx)
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.register_buffer("layer_scalar", torch.ones(1))
...
class Gemma4TextScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.scalar_embed_scale = embed_scale
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
...
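# ---- Illustrative sketch (not upstream code) ----
# The class above multiplies every looked-up embedding by a fixed scale held in a
# non-persistent buffer. The snippet below shows the effect on a tiny vocabulary.
# `_SketchScaledEmbedding` is a hypothetical local copy, and using sqrt(embedding_dim)
# as the scale is an assumption made for the example, not a value taken from this file.
import torch
from torch import nn


class _SketchScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Non-persistent: the scale is rebuilt from the constructor argument, not saved in checkpoints.
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


_sketch_dim = 16
_sketch_embed = _SketchScaledEmbedding(10, _sketch_dim, padding_idx=0, embed_scale=_sketch_dim**0.5)
_sketch_vecs = _sketch_embed(torch.tensor([[0, 3, 7]]))  # shape (1, 3, 16)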
@auto_docstring
class Gemma4PreTrainedModel(PreTrainedModel):
config: Gemma4Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = None # override
...
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
class Gemma4TextModel(Gemma4PreTrainedModel):
config: Gemma4TextConfig
input_modalities = ("text",)
_can_record_outputs = {
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
"hidden_states": Gemma4TextDecoderLayer,
"attentions": Gemma4TextAttention,
}
def __init__(self, config: Gemma4TextConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
...
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_gather_output"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config: Gemma4TextConfig
base_model_prefix = "model"
def __init__(self, config: Gemma4TextConfig):
super().__init__(config)
self.model = Gemma4TextModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Grab the ones from the child
...
def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
"""
Creates a unidirectional or bidirectional attention mask with a sliding window.
"""
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
left_window_size, right_window_size = sliding_window
dist = q_idx - kv_idx
left_mask = (dist >= 0) & (dist < left_window_size)
right_mask = (dist < 0) & (-dist < right_window_size)
return left_mask | right_mask
return inner_mask
...
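# ---- Illustrative sketch (not upstream code) ----
# The mask function above lets a query attend to keys within `left_window_size` positions
# behind it (itself included) and strictly fewer than `right_window_size` positions ahead.
# The snippet below evaluates that rule on a 4-token sequence with broadcast index tensors.
# `_sketch_sliding_window_mask` is a hypothetical local copy and the window sizes are
# arbitrary example values.
import torch


def _sketch_sliding_window_mask(sliding_window: tuple[int, int]):
    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        left_window_size, right_window_size = sliding_window
        dist = q_idx - kv_idx
        left_mask = (dist >= 0) & (dist < left_window_size)
        right_mask = (dist < 0) & (-dist < right_window_size)
        return left_mask | right_mask

    return inner_mask


_sketch_pos = torch.arange(4)
_sketch_q, _sketch_kv_pos = _sketch_pos[:, None], _sketch_pos[None, :]
# (2, 2): each query sees itself, one token behind, and one token ahead.
_sketch_bidirectional = _sketch_sliding_window_mask((2, 2))(0, 0, _sketch_q, _sketch_kv_pos)
# (2, 1): a right window of 1 admits no future tokens, so the mask is causal.
_sketch_causal = _sketch_sliding_window_mask((2, 1))(0, 0, _sketch_q, _sketch_kv_pos)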
class Gemma4AudioModel(Gemma4PreTrainedModel):
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
config: Gemma4AudioConfig
main_input_name = "input_features"
base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
_can_record_outputs = {
"hidden_states": Gemma4AudioLayer,
"attentions": Gemma4AudioAttention,
}
def __init__(self, config: Gemma4AudioConfig):
super().__init__(config)
self.config = config
...
class Gemma4VisionModel(Gemma4PreTrainedModel):
"""The Gemma 4 Vision Encoder."""
config = Gemma4VisionConfig
_can_record_outputs = {
"hidden_states": Gemma4VisionEncoderLayer,
"attentions": Gemma4VisionAttention,
}
def __init__(self, config: Gemma4VisionConfig):
super().__init__(config)
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
self.encoder = Gemma4VisionEncoder(config)
self.pooler = Gemma4VisionPooler(config)
...
class Gemma4MultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language model space."""
def __init__(
self,
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
text_config: Gemma4TextConfig,
):
super().__init__()
self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
self.eps = multimodal_config.rms_norm_eps
self.text_hidden_size = text_config.hidden_size
self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
...
def token_type_ids_mask_function(
token_type_ids: torch.Tensor | None,
image_group_ids: torch.Tensor | None,
) -> Callable | None:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
"""
# Do not return an additional mask in this case
if token_type_ids is None:
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
seq_length = image_group_ids.shape[-1]
...
def create_causal_mask_mapping(
config: PreTrainedConfig,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None,
position_ids: torch.Tensor | None,
mm_token_type_ids: torch.Tensor | None = None,
pixel_values: torch.FloatTensor | None = None,
is_training: bool = False,
is_first_iteration: bool | None = None,
**kwargs,
) -> dict:
"""
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
...
@auto_docstring(
custom_intro="""
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
language modeling head.
"""
)
class Gemma4Model(Gemma4PreTrainedModel):
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
accepts_loss_kwargs = False
def __init__(self, config: Gemma4Config):
super().__init__(config)
self.vocab_size = config.text_config.vocab_size
...
@auto_docstring(
custom_intro="""
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
head.
"""
)
class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
accepts_loss_kwargs = False
base_model_prefix = "model"
def __init__(self, config: Gemma4Config):
super().__init__(config)
self.model = Gemma4Model(config)
...
@@ -0,0 +1,563 @@
# === HEADER (license + imports) ===
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
import torch
from torch import nn
from torch.nn import functional as F
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...integrations import use_kernelized_func
from ...masking_utils import (
create_bidirectional_mask,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_accelerate_available,
logging,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..auto.modeling_auto import AutoModel
from ..gemma3.modeling_gemma3 import (
Gemma3Attention,
Gemma3DecoderLayer,
Gemma3ForCausalLM,
Gemma3MLP,
Gemma3RotaryEmbedding,
Gemma3TextModel,
Gemma3TextScaledWordEmbedding,
)
from ..gemma3n.modeling_gemma3n import (
Gemma3nCausalLMOutputWithPast,
Gemma3nForConditionalGeneration,
Gemma3nModel,
Gemma3nModelOutputWithPast,
Gemma3nMultimodalEmbedder,
Gemma3nPreTrainedModel,
Gemma3nRMSNorm,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..llama.modeling_llama import LlamaRotaryEmbedding
from ..mixtral.modeling_mixtral import MixtralExperts
from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
if is_accelerate_available():
pass
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast):
pass
class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast):
pass
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
r"""
attention_mask (`torch.BoolTensor`, *optional*):
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
...
class Gemma4ClippableLinear(nn.Module):
def __init__(
self,
config: Gemma4VisionConfig | Gemma4AudioConfig,
in_features: int,
out_features: int,
) -> None:
super().__init__()
self.use_clipped_linears = config.use_clipped_linears
self.linear = nn.Linear(in_features, out_features, bias=False)
if self.use_clipped_linears:
self.register_buffer("input_min", torch.tensor(-float("inf")))
self.register_buffer("input_max", torch.tensor(float("inf")))
...
class Gemma4RMSNorm(Gemma3nRMSNorm):
pass
class Gemma4AudioRelPositionalEncoding(nn.Module):
"""Sinusoidal relative positional encoding for the audio encoder.
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
"""
inv_timescales: torch.Tensor
def __init__(self, config: Gemma4AudioConfig):
...
class Gemma4AudioAttention(nn.Module):
"""Chunked local attention with relative position bias"""
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_logits_soft_cap = config.attention_logit_cap
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.q_scale = (self.head_dim**-0.5) / math.log(2)
self.k_scale = math.log(1 + math.e) / math.log(2)
...
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
def __init__(self, in_channels, out_channels, norm_eps):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=1,
bias=False,
)
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
self.act = nn.ReLU()
...
class Gemma4AudioSubSampleConvProjection(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
in_channels=1,
out_channels=config.subsampling_conv_channels[0],
norm_eps=config.rms_norm_eps,
)
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
in_channels=config.subsampling_conv_channels[0],
out_channels=config.subsampling_conv_channels[1],
norm_eps=config.rms_norm_eps,
)
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
...
class Gemma4AudioFeedForward(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.config = config
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
self.act_fn = ACT2FN[config.hidden_act]
self.gradient_clipping = config.gradient_clipping
self.post_layer_scale = config.residual_weight
...
class Gemma4AudioCausalConv1d(nn.Conv1d):
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# kernel_size: int,
# # cache_key: str,
# stride: int = 1,
# dilation: int = 1,
# bias: bool = True,
# ):
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
# self.cache_key = cache_key
...
class Gemma4AudioLightConv1d(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.config = config
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=config.conv_kernel_size,
groups=config.hidden_size,
bias=False,
)
...
class Gemma4AudioLayer(nn.Module):
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
super().__init__()
self.config = config
self.feed_forward1 = Gemma4AudioFeedForward(config)
self.feed_forward2 = Gemma4AudioFeedForward(config)
self.self_attn = Gemma4AudioAttention(config, layer_idx)
self.lconv1d = Gemma4AudioLightConv1d(config)
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
self.norm_out = Gemma4RMSNorm(config.hidden_size)
...
class Gemma4VisionPatchEmbedder(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
self.position_embedding_size = config.position_embedding_size
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
"""Prepare patch positions map for matmul with position embedding table."""
# Expand and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
...
class Gemma4VisionPooler(nn.Module):
"""Scaling and optional spatial pooling for vision encodings"""
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.root_hidden_size = self.hidden_size**0.5
def _avg_pool_by_positions(
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
2D spatial pooling according to patch positions.
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
...
class Gemma4VisionMLP(Gemma3MLP):
def __init__(self, config: Gemma4VisionConfig):
super().__init__(config)
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
def apply_multidimensional_rope(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: torch.Tensor,
unsqueeze_dim: int = 2,
...
class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
@staticmethod
def compute_default_rope_parameters(
config: Gemma4VisionConfig | None = None,
device: torch.device | None = None,
seq_len: int | None = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PreTrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
...
class Gemma4VisionAttention(Gemma3Attention):
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
super().__init__(config, layer_idx)
del self.attn_logit_softcapping
del self.sliding_window
del self.is_sliding
self.scaling = 1.0
self.is_causal = False
self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
...
class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
self.mlp = Gemma4VisionMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
...
class Gemma4VisionEncoder(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.num_layers = config.num_hidden_layers
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
self.layers = nn.ModuleList(
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
)
def forward(
self,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor,
...
class Gemma4TextMLP(Gemma3MLP):
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
super().__init__()
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding):
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
nn.Module.__init__(self)
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
...
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
super().__init__()
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
self.config = config
self.layer_idx = layer_idx
self.is_sliding = self.layer_type == "sliding_attention"
self.sliding_window = config.sliding_window if self.is_sliding else None
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
...
class Gemma4TextExperts(MixtralExperts):
def __init__(self, config: Gemma4TextConfig):
super().__init__()
self.num_experts = config.num_experts
self.intermediate_dim = config.moe_intermediate_size
self.act_fn = ACT2FN[config.hidden_activation]
class Gemma4TextRouter(nn.Module):
def __init__(self, config: Gemma4TextConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.scalar_root_size = self.hidden_size**-0.5
...
class Gemma4TextDecoderLayer(Gemma3DecoderLayer):
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
self.mlp = Gemma4TextMLP(config, layer_idx)
self.register_buffer("layer_scalar", torch.ones(1))
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
if self.hidden_size_per_layer_input:
self.act_fn = ACT2FN[config.hidden_activation]
self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
...
class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
pass
# ---- Model Classes ----
class Gemma4PreTrainedModel(Gemma3nPreTrainedModel):
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
input_modalities = ("image", "text", "video", "audio")
_can_record_outputs = None # override
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
@torch.no_grad()
...
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
class Gemma4TextModel(Gemma3TextModel):
config: Gemma4TextConfig
_can_record_outputs = {
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
"hidden_states": Gemma4TextDecoderLayer,
"attentions": Gemma4TextAttention,
}
def __init__(self, config: Gemma4TextConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
...
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
class Gemma4ForCausalLM(Gemma3ForCausalLM):
base_model_prefix = "model"
def __init__(self, config: Gemma4TextConfig):
super().__init__(config)
# Grab the ones from the child
self._keys_to_ignore_on_load_unexpected = [
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
]
class Gemma4AudioModel(Gemma4PreTrainedModel):
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
...
class Gemma4VisionModel(Gemma4PreTrainedModel):
"""The Gemma 4 Vision Encoder."""
config = Gemma4VisionConfig
_can_record_outputs = {
"hidden_states": Gemma4VisionEncoderLayer,
"attentions": Gemma4VisionAttention,
}
def __init__(self, config: Gemma4VisionConfig):
super().__init__(config)
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
self.encoder = Gemma4VisionEncoder(config)
self.pooler = Gemma4VisionPooler(config)
...
class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder):
def __init__(
self,
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
text_config: Gemma4TextConfig,
):
# Audio tower may use a different output dimension (output_proj_dims) than the
# internal hidden_size. Use the tower-specific dimension if specified.
super().__init__(multimodal_config, text_config)
del self.embedding
del self.hard_embedding_norm
del self.soft_embedding_norm
del self.vocab_offset
del self.vocab_size
...
def token_type_ids_mask_function(
token_type_ids: torch.Tensor | None,
image_group_ids: torch.Tensor | None,
) -> Callable | None:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
"""
# Do not return an additional mask in this case
if token_type_ids is None:
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
seq_length = image_group_ids.shape[-1]
...
def create_causal_mask_mapping(
config: PreTrainedConfig,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None,
position_ids: torch.Tensor | None,
mm_token_type_ids: torch.Tensor | None = None,
pixel_values: torch.FloatTensor | None = None,
is_training: bool = False,
is_first_iteration: bool | None = None,
**kwargs,
) -> dict:
"""
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
...
@auto_docstring(
custom_intro="""
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
language modeling head.
"""
)
class Gemma4Model(Gemma3nModel):
def __init__(self, config: Gemma4Config):
super().__init__(config)
del self.vision_tower
del self.embed_vision
self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
self.embed_vision = (
Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
...
@auto_docstring(
custom_intro="""
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
head.
"""
)
class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration):
base_model_prefix = "model"
def __init__(self, config: Gemma4Config):
super().__init__(config)
# Grab the ones from the child
self._keys_to_ignore_on_load_unexpected = [
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
...
@@ -0,0 +1,366 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
from ...audio_utils import AudioInput
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import auto_docstring, is_vision_available, logging
from ...utils.import_utils import requires
from ...video_utils import VideoInput
if is_vision_available():
from .image_processing_pil_gemma4 import Gemma4ImageProcessorKwargs, get_aspect_ratio_preserving_size
logger = logging.get_logger(__name__)
class Gemma4ProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: Gemma4ImageProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": True,
"return_mm_token_type_ids": True,
},
"images_kwargs": {
"do_convert_rgb": True,
},
"audio_kwargs": {},
"videos_kwargs": {"return_metadata": True},
}
@auto_docstring
@requires(backends=("vision",))
class Gemma4Processor(ProcessorMixin):
def __init__(
self,
feature_extractor,
image_processor,
tokenizer,
video_processor,
chat_template=None,
image_seq_length: int = 280,
audio_seq_length: int = 750,
audio_ms_per_token: int = 40,
**kwargs,
):
r"""
image_seq_length (`int`, *optional*, defaults to 280):
The number of soft tokens per image used for placeholder expansion.
audio_seq_length (`int`, *optional*, defaults to 750):
The maximum number of audio soft tokens per audio segment. Serves as an
upper-bound cap when dynamic audio token counts are computed.
audio_ms_per_token (`int`, *optional*, defaults to 40):
Milliseconds of audio per output soft token. Used to dynamically compute
the number of audio placeholder tokens as ``ceil(duration_ms / audio_ms_per_token)``.
The default of 40 comes from the SSCP convolution's 4× time reduction on 10ms frames.
"""
self.image_seq_length = image_seq_length
self.image_token_id = tokenizer.image_token_id
self.boi_token = tokenizer.boi_token
self.eoi_token = tokenizer.eoi_token
self.image_token = tokenizer.image_token
# FIXME: add the token to config and ask Ryan to re-upload
tokenizer.add_special_tokens({"additional_special_tokens": ["<|video|>"]})
self.video_token = "<|video|>"
self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
# Audio token handling, mirroring the vision pattern.
# audio_seq_length serves as the maximum cap on the number of audio soft tokens
# any single audio segment can produce. With dynamic audio tokens, the actual
# number of placeholders inserted per audio is computed from the audio duration.
self.audio_seq_length = audio_seq_length
# Milliseconds of audio per output soft token. The default of 40 comes from the
# SSCP convolution's 4× time reduction applied to 10ms mel spectrogram frames.
self.audio_ms_per_token = audio_ms_per_token
self.audio_token_id = getattr(tokenizer, "audio_token_id", None)
self.audio_token = getattr(tokenizer, "audio_token", None)
self.boa_token = getattr(tokenizer, "boa_token", None)
self.eoa_token = getattr(tokenizer, "eoa_token", None)
super().__init__(
feature_extractor=feature_extractor,
image_processor=image_processor,
tokenizer=tokenizer,
video_processor=video_processor,
chat_template=chat_template,
**kwargs,
)
@auto_docstring
def __call__(
self,
images: ImageInput | None = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
audio: AudioInput | None = None,
videos: VideoInput | None = None,
**kwargs: Unpack[Gemma4ProcessorKwargs],
) -> BatchFeature:
if text is None and images is None and audio is None and videos is None:
raise ValueError("Provide at least one of `text`, `images`, `audio`, or `videos`.")
output_kwargs = self._merge_kwargs(
Gemma4ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = [text]
elif text is not None and not isinstance(text, list) and not isinstance(text[0], str):
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
image_inputs = {}
if images is not None:
images = self.image_processor.fetch_images(images)
batched_images = make_nested_list_of_images(images)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
num_soft_tokens = image_inputs.pop("num_soft_tokens_per_image")
# Create empty text to be replaced with placeholders
if not text:
text = [" ".join([self.image_token] * len(images)) for images in batched_images]
if len(batched_images) != len(text):
raise ValueError(
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
)
replacements = [f"{self.boi_token}{self.image_token * n}{self.eoi_token}" for n in num_soft_tokens]
replacements_iter = iter(replacements)
# Expand image_token placeholders to per-image soft token sequences.
# re.sub never re-scans replaced text, so this is safe even though each replacement contains image_token itself
pattern = re.escape(self.image_token)
text = [re.sub(pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
# Process video inputs in the same way
video_inputs = {}
if videos is not None:
video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
num_video_tokens = video_inputs.pop("num_soft_tokens_per_video")
# If user has not requested video metadata, pop it so it isn't returned
if not kwargs.get("return_metadata"):
video_metadata = video_inputs.pop("video_metadata")
else:
video_metadata = video_inputs["video_metadata"]
video_replacements = []
for metadata, n_tokens in zip(video_metadata, num_video_tokens):
if metadata.fps is None:
logger.warning_once(
"Gemma 4 requires frame timestamps to construct prompts, but the `fps` of the input video "
"could not be inferred. Probably `video_metadata` was missing from inputs and you passed "
"pre-sampled frames. Defaulting to `fps=24`. Please provide `video_metadata` for more "
"accurate results."
)
metadata.fps = 24 if metadata.fps is None else metadata.fps
# mm:ss format for timestamps
timestamp_str = [
f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" for seconds in metadata.timestamps
]
video_replacements.append(
" ".join(
[f"{t} {self.boi_token}{self.video_token * n_tokens}{self.eoi_token}" for t in timestamp_str]
)
)
video_replacements = iter(video_replacements)
pattern = re.escape(self.video_token)
text = [re.sub(pattern, lambda _: next(video_replacements), prompt) for prompt in text]
# Process audio inputs
audio_inputs = {}
if audio is not None:
if self.audio_token is None or self.boa_token is None or self.eoa_token is None:
raise ValueError(
"Audio inputs were provided, but the tokenizer does not define all of `audio_token`, `boa_token`, and `eoa_token`."
)
# Normalize audio input to list of waveforms
if isinstance(audio, np.ndarray) and audio.ndim == 1:
audio = [audio]
# TODO: Add tests for audio-only processor inputs.
if not text:
text = [self.audio_token] * len(audio)
# Dynamic audio token expansion without padding:
# * Extract audio features with feature extractor;
# * Compute precise per-audio token counts from the waveform duration;
# * Generate full audio token sequence for each computed audio length;
# * Expand text prompts with full audio token sequences.
audio_kwargs = output_kwargs.get("audio_kwargs", {})
audio_inputs = self.feature_extractor(audio, **audio_kwargs)
sampling_rate = self.feature_extractor.sampling_rate
num_audio_tokens = [self._compute_audio_num_tokens(a, sampling_rate) for a in audio]
replacements = [f"{self.boa_token}{self.audio_token * n}{self.eoa_token}" for n in num_audio_tokens]
replacements_iter = iter(replacements)
audio_pattern = re.escape(self.audio_token)
text = [re.sub(audio_pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
# Check special tokens for all active modalities
active_modalities = []
if images is not None:
active_modalities.append("image")
if videos is not None:
active_modalities.append("video")
if audio is not None:
active_modalities.append("audio")
if active_modalities:
self._check_special_mm_tokens(text, text_inputs, modalities=active_modalities)
if return_mm_token_type_ids:
text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
return BatchFeature(
data={**text_inputs, **image_inputs, **audio_inputs, **video_inputs},
tensor_type=return_tensors,
)
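The placeholder expansion used for images, videos, and audio in `__call__` follows one pattern: build one replacement string per input, then let `re.sub` consume them in order through an iterator. The following is a minimal standalone sketch of that pattern, not the processor's own code. The helper name `expand_placeholders` and the token strings `<img>`, `<boi>`, and `<eoi>` are hypothetical stand-ins for the tokenizer's real special tokens.

```python
import re


def expand_placeholders(prompts, placeholder, begin, end, counts):
    """Expand the i-th placeholder (counted across all prompts) to counts[i] soft tokens.

    Hypothetical helper: mirrors the iterator-plus-re.sub pattern, nothing more.
    """
    # One replacement string per input, e.g. "<boi><img><img><eoi>" for n = 2.
    replacements = iter(f"{begin}{placeholder * n}{end}" for n in counts)
    pattern = re.escape(placeholder)
    # The callback ignores the match and simply takes the next replacement.
    # re.sub does not re-scan inserted text, so the placeholder copies inside
    # each replacement are not expanded a second time.
    return [re.sub(pattern, lambda _match: next(replacements), prompt) for prompt in prompts]


prompts = ["Compare <img> with <img>.", "Describe <img>."]
expanded = expand_placeholders(prompts, "<img>", "<boi>", "<eoi>", [2, 1, 3])
for line in expanded:
    print(line)
```

Because the iterator is shared across the whole batch, the counts must be listed in the same order in which the placeholders appear across the prompts.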
def _compute_audio_num_tokens(self, audio_waveform, sampling_rate: int) -> int:
"""Compute the number of audio soft tokens for a single waveform.
Replicates the exact sequence-length arithmetic of the audio encoder
so that the processor inserts the correct number of placeholder tokens.
The computation mirrors:
1. Mel framing via ``_unfold`` in ``Gemma4AudioFeatureExtractor``
2. Two ``Conv2d`` subsampling layers in ``Gemma4AudioSubSampleConvProjection``
(each: kernel=3, stride=2, semicausal padding top=1, bottom=1)
The result is capped at ``self.audio_seq_length`` (the configured maximum).
Args:
audio_waveform: A 1-D numpy array or list containing the raw audio samples.
sampling_rate: The sampling rate of the audio waveform in Hz.
Returns:
The number of audio soft tokens to insert as placeholders.
"""
num_samples = len(audio_waveform)
# Step 1: Mel frames (matches feature_extraction_gemma4.py _unfold)
frame_length = int(round(sampling_rate * 20.0 / 1000.0)) # 320 @ 16kHz
hop_length = int(round(sampling_rate * 10.0 / 1000.0)) # 160 @ 16kHz
frame_size_for_unfold = frame_length + 1 # 321
# The feature extractor prepends (frame_length // 2) zero samples as
# semicausal time-padding before the unfold. We must include this to
# match the actual number of mel frames it produces.
pad_left = frame_length // 2 # 160 @ 16kHz
padded_samples = num_samples + pad_left
num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1
if num_mel_frames <= 0:
return 0
# Step 2: Two SSCP conv layers (kernel=3, stride=2, semicausal pad top=1, bottom=1)
# Each layer: T_out = (T_in + pad_top + pad_bottom - kernel) // stride + 1
t = num_mel_frames
for _ in range(2):
t_padded = t + 2 # pad_top=1, pad_bottom=1
t = (t_padded - 3) // 2 + 1
# Cap at the configured maximum
return min(t, self.audio_seq_length)
def _get_num_multimodal_tokens(self, image_sizes=None, audio_lengths=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`list[list[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
audio_lengths (`list[int]`, *optional*):
The lengths of audio inputs in number of samples. Used to dynamically
compute per-audio token counts.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
# Copy the defaults so that update() does not mutate the class-level dict.
images_kwargs = dict(Gemma4ProcessorKwargs._defaults.get("images_kwargs", {}))
images_kwargs.update(kwargs)
patch_size = images_kwargs.get("patch_size", None) or self.image_processor.patch_size
pooling_kernel_size = (
images_kwargs.get("pooling_kernel_size", None) or self.image_processor.pooling_kernel_size
)
max_soft_tokens = images_kwargs.get("max_soft_tokens", None) or self.image_processor.max_soft_tokens
max_patches = max_soft_tokens * pooling_kernel_size**2
vision_data = {}
if image_sizes is not None:
num_image_tokens = []
for image_size in image_sizes:
target_h, target_w = get_aspect_ratio_preserving_size(
height=image_size[0],
width=image_size[1],
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
)
patch_height = target_h // patch_size
patch_width = target_w // patch_size
num_image_tokens.append(patch_height * patch_width // pooling_kernel_size**2)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if audio_lengths is not None:
# Dynamically compute per-audio token counts from sample lengths.
# audio_lengths are in number of samples; assume default sampling rate.
sampling_rate = getattr(self.feature_extractor, "sampling_rate", 16_000)
num_audio_tokens = [
self._compute_audio_num_tokens(np.zeros(length), sampling_rate) for length in audio_lengths
]
vision_data.update({"num_audio_tokens": num_audio_tokens})
return MultiModalData(**vision_data)
@property
def model_input_names(self):
model_input_names = super().model_input_names
model_input_names = [
name
for name in model_input_names
if name not in ["num_soft_tokens_per_image", "num_soft_tokens_per_video"]
]
# Include audio feature extractor input names if available
if self.feature_extractor is not None:
feature_extractor_input_names = self.feature_extractor.model_input_names
model_input_names.extend([name for name in feature_extractor_input_names if name not in model_input_names])
return model_input_names + ["mm_token_type_ids"]
__all__ = ["Gemma4Processor"]
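As a worked example of the arithmetic in `_compute_audio_num_tokens`, the sketch below reimplements the same steps as a standalone function so the numbers can be checked without loading a processor. The function name `audio_soft_tokens` and its keyword defaults are assumptions made for illustration (16 kHz sampling rate and a cap of 750, matching the defaults shown in the listing); it is not part of the library.

```python
def audio_soft_tokens(num_samples, sampling_rate=16_000, max_tokens=750):
    """Standalone sketch of the mel-framing plus two-conv length arithmetic."""
    frame_length = int(round(sampling_rate * 20.0 / 1000.0))  # 320 samples at 16 kHz
    hop_length = int(round(sampling_rate * 10.0 / 1000.0))    # 160 samples at 16 kHz
    # The feature extractor left-pads by frame_length // 2 and unfolds
    # windows of frame_length + 1 samples.
    padded_samples = num_samples + frame_length // 2
    num_mel_frames = (padded_samples - (frame_length + 1)) // hop_length + 1
    if num_mel_frames <= 0:
        return 0
    # Two conv layers, each with kernel=3, stride=2, and one frame of padding per side.
    t = num_mel_frames
    for _ in range(2):
        t = (t + 2 - 3) // 2 + 1
    return min(t, max_tokens)


one_second = audio_soft_tokens(16_000)
thirty_seconds = audio_soft_tokens(30 * 16_000)
one_minute = audio_soft_tokens(60 * 16_000)
too_short = audio_soft_tokens(100)
print(one_second, thirty_seconds, one_minute, too_short)  # 25 750 750 0
```

One second of 16 kHz audio yields 99 mel frames, then 50 and 25 frames after the two stride-2 layers, which is where the figure of 40 ms per token comes from. Thirty seconds lands exactly on the default cap of 750, so longer clips are truncated to that count.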
@@ -0,0 +1,237 @@
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...image_processing_utils import BatchFeature
from ...processing_utils import Unpack, VideosKwargs
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
)
from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
from ...video_utils import VideoInput
from .image_processing_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size
if is_vision_available():
from ...image_utils import PILImageResampling
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
class Gemma4VideoProcessorKwargs(VideosKwargs, total=False):
"""
patch_size (`int`, *optional*):
Size of each image patch in pixels.
max_soft_tokens (`int`, *optional*):
Maximum number of soft (vision) tokens per video frame.
Must be one of {70, 140, 280, 560, 1120}.
pooling_kernel_size (`int`, *optional*):
Spatial pooling kernel size applied after patchification.
"""
patch_size: int
max_soft_tokens: int
pooling_kernel_size: int
def convert_video_to_patches(video: "torch.Tensor", patch_size: int) -> "torch.Tensor":
"""
Convert 4D tensor video of shape (num_frames, num_channels, height, width) into 3D tensor of patches of shape
(num_frames, num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
"""
num_frames, num_channels, height, width = video.shape
num_patches_height = height // patch_size
num_patches_width = width // patch_size
patched_video = video.reshape(
num_frames, num_channels, num_patches_height, patch_size, num_patches_width, patch_size
)
patched_video = patched_video.permute(0, 2, 4, 3, 5, 1)
patched_video = patched_video.reshape(num_frames, num_patches_height * num_patches_width, -1)
return patched_video
def pad_to_max_patches(
video: "torch.Tensor", positions: "torch.Tensor", target_length: int
) -> tuple["torch.Tensor", "torch.Tensor"]:
"""
Pad the video and its position ids along the patch dimension up to `target_length` patches.
"""
current_length = video.shape[1]
padding_length = target_length - current_length
if padding_length > 0:
padding = [0, 0, 0, padding_length, 0, 0]
pos_padding = (0, 0, 0, padding_length, 0, 0)
video = torch.nn.functional.pad(video, padding, mode="constant", value=0)
positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1)
return video, positions
@add_start_docstrings(
"Constructs a Gemma4 video processor that samples frames from videos for use with the Gemma4 model.",
BASE_VIDEO_PROCESSOR_DOCSTRING,
)
class Gemma4VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BICUBIC
image_mean = [0.0, 0.0, 0.0]
image_std = [1.0, 1.0, 1.0]
size = None
default_to_square = True
do_convert_rgb = True
do_resize = True
do_rescale = True
do_normalize = True
num_frames = 32
do_sample_frames = True
patch_size = 16
max_soft_tokens = 70
pooling_kernel_size = 3
valid_kwargs = Gemma4VideoProcessorKwargs
model_input_names = ["pixel_values_videos", "video_position_ids"]
def __init__(self, **kwargs: Unpack[Gemma4VideoProcessorKwargs]):
super().__init__(**kwargs)
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
def _validate_preprocess_kwargs(self, **kwargs):
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
# parameter. Temporarily disable do_resize so the base validation
# doesn't require `size` to be set.
kwargs["do_resize"] = False
super()._validate_preprocess_kwargs(**kwargs)
def aspect_ratio_preserving_resize(
self,
video: torch.Tensor,
patch_size: int,
max_patches: int,
pooling_kernel_size: int,
resample: F.InterpolationMode,
) -> torch.Tensor:
height, width = video.shape[-2], video.shape[-1]
target_height, target_width = get_aspect_ratio_preserving_size(
height=height,
width=width,
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
)
if target_height == height and target_width == width:
return video
return F.resize(
video,
size=[target_height, target_width],
interpolation=resample,
antialias=True,
)
def preprocess(
self,
videos: VideoInput,
**kwargs: Unpack[Gemma4VideoProcessorKwargs],
) -> BatchFeature:
return super().preprocess(videos, **kwargs)
def _preprocess(
self,
videos: list["torch.Tensor"],
do_resize: bool,
resample: "F.InterpolationMode | int | None",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
return_tensors: str | TensorType | None,
patch_size: int | None = None,
max_soft_tokens: int | None = None,
pooling_kernel_size: int | None = None,
**kwargs,
) -> BatchFeature:
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
max_patches = max_soft_tokens * pooling_kernel_size**2
pixel_values = []
position_ids = []
num_soft_tokens_per_video = []
num_frames = 1
for video in videos:
if do_resize:
video = self.aspect_ratio_preserving_resize(
video=video,
patch_size=patch_size,
max_patches=max_patches,
pooling_kernel_size=pooling_kernel_size,
resample=resample,
)
video = self.rescale_and_normalize(video, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
num_frames = video.shape[0]
patch_height = video.shape[-2] // patch_size
patch_width = video.shape[-1] // patch_size
patches = convert_video_to_patches(video, patch_size)
num_soft_tokens_per_video.append(patches.shape[1] // pooling_kernel_size**2)
device = video.device
patch_grid = torch.meshgrid(
torch.arange(patch_width, device=device),
torch.arange(patch_height, device=device),
indexing="xy",
)
stacked_grid = torch.stack(patch_grid, dim=-1)
real_positions = stacked_grid.reshape(patches.shape[1], 2)
real_positions = real_positions[None, ...].repeat(num_frames, 1, 1)
patches, positions = pad_to_max_patches(patches, real_positions, max_patches)
pixel_values.append(patches)
position_ids.append(positions)
# Stack into batch tensors
pixel_values = torch.stack(pixel_values, dim=0) # (num_videos, num_frames, max_patches, patch_pixels)
position_ids = torch.stack(position_ids, dim=0) # (num_videos, num_frames, max_patches, 2)
data = {
"pixel_values_videos": pixel_values,
"video_position_ids": position_ids,
"num_soft_tokens_per_video": num_soft_tokens_per_video,
}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma4VideoProcessor"]
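To make the patch bookkeeping in `_preprocess` concrete, the sketch below reproduces the per-frame numbers in plain Python: the patch budget `max_soft_tokens * pooling_kernel_size**2`, the `(x, y)` position grid in the order that `torch.meshgrid(..., indexing="xy")` followed by a row-major reshape produces, and the `-1` padding applied by `pad_to_max_patches`. The helper name `patch_position_ids` is hypothetical, and the sketch assumes the frame has already been resized so that both sides are multiples of `patch_size`.

```python
def patch_position_ids(height, width, patch_size=16, max_soft_tokens=70, pooling_kernel_size=3):
    """Sketch of per-frame patch positions and soft-token count (illustrative only)."""
    max_patches = max_soft_tokens * pooling_kernel_size**2  # 70 * 9 = 630 with the defaults
    patch_height = height // patch_size
    patch_width = width // patch_size
    # Row-major walk over the grid: x varies fastest, each entry is (x, y).
    positions = [(x, y) for y in range(patch_height) for x in range(patch_width)]
    # Each soft token pools pooling_kernel_size**2 patches.
    num_soft_tokens = len(positions) // pooling_kernel_size**2
    # Pad up to the patch budget with (-1, -1); no-op if already at or above it.
    positions += [(-1, -1)] * (max_patches - len(positions))
    return positions, num_soft_tokens


positions, num_soft_tokens = patch_position_ids(height=96, width=144)
print(len(positions), num_soft_tokens)  # 630 6
print(positions[:3], positions[9], positions[53], positions[54])
```

A 96×144 frame gives a 6×9 grid of 54 real patches and 6 soft tokens. The remaining 576 rows are padding, marked with position `-1`.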