docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
|
||||
|
||||
|
||||
def collate_fn(examples, processor):
|
||||
messages = list()
|
||||
for sample in examples:
|
||||
image = sample["image"].convert("RGB")
|
||||
label = str(sample["label"])
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
batch = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask the tokens that we do not want to include in the loss computation
|
||||
# -100 is ignored during categorical cross entropy loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def freeze_layers(model):
|
||||
for name, param in model.named_parameters():
|
||||
if "attn" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def run_inference(val_dataset, processor, model, fname):
|
||||
# infer before training
|
||||
val_sample = random.choice(val_dataset)
|
||||
image = val_sample["image"].convert("RGB")
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device, dtype=torch.bfloat16)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
with torch.no_grad():
|
||||
generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
|
||||
plt.imshow(image)
|
||||
plt.axis("off")
|
||||
plt.title(f"Pred: {decoded}")
|
||||
plt.show()
|
||||
plt.savefig(f"outputs_fine_tune/{fname}")
|
||||
|
||||
|
||||
def main():
|
||||
model_id = "google/gemma-3n-E2B-it"
|
||||
processor = Gemma3nProcessor.from_pretrained(model_id)
|
||||
|
||||
# load the dataset
|
||||
dataset_id = "ariG23498/intersection-dataset"
|
||||
train_dataset = load_dataset(dataset_id, split="train")
|
||||
val_dataset = load_dataset(dataset_id, split="validation")
|
||||
|
||||
# create data loader
|
||||
partial_collate_fn = partial(collate_fn, processor=processor)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
)
|
||||
|
||||
# load the model and optimizer
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")
|
||||
|
||||
run_inference(val_dataset, processor, model, "pred_before.png")
|
||||
|
||||
model = freeze_layers(model)
|
||||
|
||||
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
|
||||
|
||||
# Start Training
|
||||
accumulation_steps = 8
|
||||
for idx, batch in tqdm(enumerate(train_dataloader)):
|
||||
outputs = model(**batch.to(model.device))
|
||||
loss = outputs.loss / accumulation_steps
|
||||
if idx % 50 == 0:
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
count = 0
|
||||
for val_batch in val_dataloader:
|
||||
val_loss = val_loss + model(**val_batch.to(model.device)).loss
|
||||
count = count + 1
|
||||
val_loss = val_loss / count
|
||||
print(
|
||||
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
|
||||
)
|
||||
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
|
||||
|
||||
loss.backward()
|
||||
if idx % 8 == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user