eecebe7ef5
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
165 lines
4.8 KiB
Python
165 lines
4.8 KiB
Python
# Copyright 2026 DeepMind Technologies Limited.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
r"""Example of Gemma finetuning for an image captioning task.
|
|
|
|
Example:
|
|
|
|
Prompt:
|
|
|
|
```
|
|
<start_of_turn>user
|
|
<start_of_image><end_of_turn>
|
|
<start_of_turn>model
|
|
```
|
|
|
|
Target:
|
|
|
|
```
|
|
A diagram showing a circuit with a battery, lamp, and switch.<end_of_turn>
|
|
```
|
|
|
|
Here, the prompt contains only the `<start_of_image>` token, which marks where
the image is inserted.
|
|
|
|
Train locally with:
|
|
|
|
```sh
|
|
python -m kauldron.main \
|
|
--cfg=examples/multimodal.py \
|
|
--cfg.workdir=/tmp/kauldron_oss/workdir
|
|
```
|
|
|
|
"""
|
|
|
|
from kauldron import konfig
|
|
|
|
# pylint: disable=g-import-not-at-top
|
|
with konfig.imports():
|
|
import jax.numpy as jnp
|
|
from gemma import gm
|
|
from kauldron import kd
|
|
import optax
|
|
# pylint: enable=g-import-not-at-top
|
|
|
|
|
|
def get_config():
  """Returns the Kauldron `Trainer` config for image-captioning finetuning.

  Finetunes `gm.nn.Gemma3_4B`, initialised from the `GEMMA3_4B_IT`
  checkpoint, on the `ai2dcaption` dataset for 10k steps with Adafactor.
  Two evaluators run every 1000 steps:
    * "test": a loss evaluation over the test split.
    * "sampling": a `SamplerEvaluator` that generates up to 50 new tokens for
      a few test examples.
  """
  batch_size = 32
  max_length = 200
  # Both evaluators use the same cadence.
  eval_every_steps = 1000

  def show_images():
    # Build a fresh summary config at each call site rather than sharing one
    # config node between the trainer and the sampler evaluator.
    return kd.summaries.ShowImages(images="batch.image", num_images=5)

  # Dataset
  train_ds = _make_dataset(
      training=True,
      batch_size=batch_size,
      max_length=max_length,
  )

  # Model definition: token ids are read from `batch.input` and images from
  # `batch.image`.
  model = gm.nn.Gemma3_4B(tokens="batch.input", images="batch.image")

  # Initialise the weights from the pretrained `GEMMA3_4B_IT` checkpoint.
  init_transform = gm.ckpts.LoadCheckpoint(
      path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
  )

  # Training: cross-entropy of `preds.logits` against `batch.target`, masked
  # with `batch.loss_mask`.
  train_losses = {
      "xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
          logits="preds.logits",
          labels="batch.target",
          mask="batch.loss_mask",
      ),
  }
  train_summaries = {"image": show_images()}
  optimizer = optax.adafactor(learning_rate=1e-3)
  checkpointer = kd.ckpts.Checkpointer(save_interval_steps=500)

  # Evaluation: loss on the test split.
  loss_eval = kd.evals.Evaluator(
      run=kd.evals.EveryNSteps(eval_every_steps),
      ds=_make_dataset(
          training=False,
          batch_size=4,
          max_length=max_length,
      ),
  )
  # Runs inference on a few prompts from the test set (`num_batches=3`).
  sampling_eval = gm.evals.SamplerEvaluator(
      run=kd.evals.EveryNSteps(eval_every_steps),
      max_new_tokens=50,  # Sampling parameters
      num_batches=3,
      ds=_make_dataset(training=False, sampling=True),
      summaries={"image": show_images()},
  )

  return kd.train.Trainer(
      seed=42,
      train_ds=train_ds,
      model=model,
      init_transform=init_transform,
      num_train_steps=10_000,
      train_losses=train_losses,
      train_summaries=train_summaries,
      optimizer=optimizer,
      checkpointer=checkpointer,
      evals={
          "test": loss_eval,
          "sampling": sampling_eval,
      },
  )
|
|
|
|
|
|
def _make_dataset(
    *,
    training: bool,
    sampling: bool = False,
    batch_size: int | None = None,
    max_length: int | None = None,
):
  """Returns the `ai2dcaption` input-pipeline config.

  Each example is reduced to its image and caption. A constant
  `<start_of_image>` prompt is added, and the prompt/caption pair is turned into
  `input` / `target` / `loss_mask` fields by `gm.data.Seq2SeqTask`. The image
  is resized to 800x800, given a size-1 `num_images` axis and cast to uint8.

  Args:
    training: If True, read the shuffled `llava_15` split with no epoch limit
      (`num_epochs=None`). Otherwise read the `test` split, unshuffled, for a
      single epoch.
    sampling: If True, build the pipeline for `gm.evals.SamplerEvaluator`.
      Examples are left unbatched and unpadded (`batch_size` and `max_length`
      are ignored), and `Seq2SeqTask` is run with `sampling=True`.
    batch_size: Examples per batch. Ignored when `sampling` is True.
    max_length: Token length that sequences are padded to and truncated at.
      Ignored when `sampling` is True.

  Returns:
    A `kd.data.py.Tfds` config.
  """
  tokenizer = gm.text.Gemma3Tokenizer()

  return kd.data.py.Tfds(
      name="ai2dcaption",
      split="llava_15" if training else "test",
      shuffle=training,
      # No epoch limit while training; a single pass for evaluation.
      num_epochs=None if training else 1,
      # The sampler consumes unbatched examples.
      batch_size=None if sampling else batch_size,
      num_workers=4,
      transforms=[
          # Only keep the fields we need. See fields at:
          # https://www.tensorflow.org/datasets/catalog/ai2dcaption
          kd.data.Elements(keep=["image", "caption"]),
          # Every example gets the same prompt: just the image placeholder.
          kd.data.AddConstants({"prompt": "<start_of_image>"}),
          # Create the model inputs/targets/loss_mask.
          gm.data.Seq2SeqTask(
              # Select which field from the dataset to use.
              in_prompt="prompt",
              in_response="caption",
              # Output batch is {"input": ..., "target": ..., "loss_mask": ...}
              out_input="input",
              out_target="target",
              out_target_mask="loss_mask",
              tokenizer=tokenizer,
              # Padding parameters (no fixed length when sampling).
              max_length=None if sampling else max_length,
              # Some prompt+caption sequences tokenize to more than
              # `max_length` tokens; truncate them to fit.
              truncate=True,
              sampling=sampling,
          ),
          kd.data.py.Resize(key="image", size=(800, 800)),
          # Insert a size-1 `num_images` axis before (h, w, c): one image per
          # example.
          # TODO(epot): Make the `num_images` dimension optional
          kd.data.Rearrange(key="image", pattern="... h w c -> ... 1 h w c"),
          # Ensure the image dtype is uint8 after resizing.
          kd.data.Cast(key="image", dtype=jnp.uint8),
      ],
  )
|