docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,526 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "colab-badge"
|
||||
},
|
||||
"source": [
|
||||
"<table align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/apps/Gemma_4_HDP_Agentic_Security/Gemma_4_HDP_Agentic_Security.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "byline"
|
||||
},
|
||||
"source": [
|
||||
"# Securing Gemma 4 Agentic Workflows with HDP\n",
|
||||
"\n",
|
||||
"**Author:** Asiri Dalugoda, Helixar Limited ([@asiridalugoda](https://github.com/asiridalugoda)) | [helixar.ai](https://helixar.ai)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gpu-instructions"
|
||||
},
|
||||
"source": [
|
||||
"## Before you begin\n",
|
||||
"\n",
|
||||
"This notebook requires a GPU runtime. To enable GPU in Colab:\n",
|
||||
"1. Go to **Runtime → Change runtime type**\n",
|
||||
"2. Set **Hardware accelerator** to **GPU** (T4 is sufficient for E4B)\n",
|
||||
"3. Click **Save**\n",
|
||||
"\n",
|
||||
"You will also need a **Hugging Face token** to download Gemma 4 (gated model):\n",
|
||||
"1. Go to [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
|
||||
"2. Create a token with **Read** access\n",
|
||||
"3. Accept the Gemma 4 model license at [huggingface.co/google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)\n",
|
||||
"4. Run the cell below to authenticate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hf-login"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "overview"
|
||||
},
|
||||
"source": [
|
||||
"# Securing Gemma 4 Agentic Workflows with HDP\n",
|
||||
"\n",
|
||||
"**Human Delegation Provenance (HDP)** is an open protocol that adds a cryptographic chain-of-custody to AI agent function calls — ensuring every tool invocation can be traced back to an authorized human principal.\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to integrate HDP with Gemma 4's native function-calling capability to:\n",
|
||||
"\n",
|
||||
"- **Verify** that Gemma 4's function calls were authorized by a human principal before execution\n",
|
||||
"- **Classify** actions by irreversibility (read-only → irreversible → physical actuation)\n",
|
||||
"- **Block** unauthorized or out-of-scope tool calls at the middleware layer\n",
|
||||
"- **Audit** every decision with a pre-execution log\n",
|
||||
"\n",
|
||||
"This is particularly relevant for Gemma 4 deployments on edge devices (Jetson Nano, Raspberry Pi) where the model may be directing physical actuators offline with no out-of-band authorization check.\n",
|
||||
"\n",
|
||||
"**References:**\n",
|
||||
"- HDP IETF draft: [draft-helixar-hdp-agentic-delegation-00](https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/)\n",
|
||||
"- HDP-P (physical AI agents): [DOI 10.5281/ZENODO.19332440](https://doi.org/10.5281/ZENODO.19332440)\n",
|
||||
"- Helixar: [helixar.ai](https://helixar.ai)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b3600ee25c8e"
|
||||
},
|
||||
"source": [
|
||||
"## Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "7a80251f52b3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q transformers torch cryptography"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ed80fe18f255"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Download the middleware\n",
|
||||
"!wget -q https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/apps/Gemma_4_HDP_Agentic_Security/hdp_middleware.py\n",
|
||||
"\n",
|
||||
"from hdp_middleware import (\n",
|
||||
" HDPDelegationToken,\n",
|
||||
" HDPMiddleware,\n",
|
||||
" IrreversibilityClass,\n",
|
||||
" DEFAULT_TOOL_CLASS_MAP,\n",
|
||||
")\n",
|
||||
"from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey\n",
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e88bdc7b7265"
|
||||
},
|
||||
"source": [
|
||||
"## 1. Load Gemma 4\n",
|
||||
"\n",
|
||||
"We use the 4B Effective model for this demo. For production agentic deployments, the 26B MoE or 31B Dense models are recommended."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1e4e7779806d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"# For edge/robotics use cases: swap to google/gemma-4-E2B-it\n",
|
||||
"MODEL_ID = \"google/gemma-4-E4B-it\"\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\n",
|
||||
" \"text-generation\",\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "d91e36cfb0b2"
|
||||
},
|
||||
"source": [
|
||||
"## 2. Define Tools\n",
|
||||
"\n",
|
||||
"Gemma 4 uses structured JSON function-calling. We define a tool set spanning different IrreversibilityClasses to demonstrate the middleware's classification behaviour."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1becdb52e7f8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TOOLS = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"get_weather\",\n",
|
||||
" \"description\": \"Get the current weather for a location.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\"type\": \"string\", \"description\": \"City name\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"send_email\",\n",
|
||||
" \"description\": \"Send an email to a recipient.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"to\": {\"type\": \"string\"},\n",
|
||||
" \"subject\": {\"type\": \"string\"},\n",
|
||||
" \"body\": {\"type\": \"string\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"to\", \"subject\", \"body\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"delete_file\",\n",
|
||||
" \"description\": \"Permanently delete a file by path.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"path\": {\"type\": \"string\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"path\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"actuate_robot_arm\",\n",
|
||||
" \"description\": \"Command a robot arm to move to a target position.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"joint_angles\": {\"type\": \"array\", \"items\": {\"type\": \"number\"}},\n",
|
||||
" \"force_limit_n\": {\"type\": \"number\"}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"joint_angles\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Tools indexed by name for lookup\n",
|
||||
"TOOL_REGISTRY = {t[\"name\"]: t for t in TOOLS}\n",
|
||||
"print(f\"Registered {len(TOOLS)} tools\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "722948b00a92"
|
||||
},
|
||||
"source": [
|
||||
"## 3. Issue an HDP Delegation Token\n",
|
||||
"\n",
|
||||
"The human principal generates an Ed25519 keypair and issues an HDT that specifies:\n",
|
||||
"- Which tools the agent is permitted to call\n",
|
||||
"- The maximum IrreversibilityClass the agent can act on\n",
|
||||
"- The token's lifetime"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "b0622c68dfa5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Human principal generates their signing keypair\n",
|
||||
"# In production: loaded from secure key storage (HSM, OS keychain, etc.)\n",
|
||||
"principal_private_key = Ed25519PrivateKey.generate()\n",
|
||||
"principal_public_key = principal_private_key.public_key()\n",
|
||||
"\n",
|
||||
"# Issue an HDT authorizing the Gemma 4 agent to call weather queries\n",
|
||||
"# and send emails (Class 0 and Class 2), but NOT delete files or actuate hardware\n",
|
||||
"token = HDPDelegationToken.issue(\n",
|
||||
" principal_id=\"alice@example.com\",\n",
|
||||
" agent_id=\"gemma4-agent-01\",\n",
|
||||
" scope=[\"get_weather\", \"send_email\"],\n",
|
||||
" max_class=IrreversibilityClass.CLASS_2,\n",
|
||||
" ttl_seconds=3600,\n",
|
||||
" private_key=principal_private_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(json.dumps(token.to_dict(), indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e206f950f4bc"
|
||||
},
|
||||
"source": [
|
||||
"## 4. Initialise the HDP Middleware\n",
|
||||
"\n",
|
||||
"The middleware takes the principal's **public key** only — it verifies but cannot issue tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e24676f528bf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"audit_log = []\n",
|
||||
"\n",
|
||||
"# Confirmation callback for Class 2 (irreversible) actions.\n",
|
||||
"# In production: this would invoke a push notification, SMS OTP,\n",
|
||||
"# or hardware confirmation device to the human principal.\n",
|
||||
"def require_human_confirmation(tool_name: str, parameters: dict) -> bool:\n",
|
||||
" print(f\"\\n⚠️ Class 2 action requested: {tool_name}\")\n",
|
||||
" print(f\" Parameters: {json.dumps(parameters, indent=4)}\")\n",
|
||||
" response = input(\" Confirm? [y/N]: \").strip().lower()\n",
|
||||
" return response == \"y\"\n",
|
||||
"\n",
|
||||
"middleware = HDPMiddleware(\n",
|
||||
" public_key=principal_public_key,\n",
|
||||
" tool_class_map=DEFAULT_TOOL_CLASS_MAP,\n",
|
||||
" confirmation_callback=require_human_confirmation,\n",
|
||||
" audit_log=audit_log,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"HDP middleware initialised.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "72d56542eba0"
|
||||
},
|
||||
"source": [
|
||||
"## 5. Gemma 4 Function Call → HDP Gate → Tool Execution\n",
|
||||
"\n",
|
||||
"This is the core integration pattern. Every function call Gemma 4 generates is passed through `middleware.gate()` before being forwarded to tool execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "da20bc191e71"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Simulated Gemma 4 function call outputs\n",
|
||||
"# In production these come from parsing Gemma 4's structured JSON output\n",
|
||||
"gemma_function_calls = [\n",
|
||||
" # ✅ Should ALLOW — Class 0, in scope\n",
|
||||
" {\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}},\n",
|
||||
"\n",
|
||||
" # ⚠️ Should CONFIRM then ALLOW — Class 2, in scope\n",
|
||||
" {\"name\": \"send_email\", \"parameters\": {\n",
|
||||
" \"to\": \"bob@example.com\",\n",
|
||||
" \"subject\": \"Weekly report\",\n",
|
||||
" \"body\": \"Please find attached.\"\n",
|
||||
" }},\n",
|
||||
"\n",
|
||||
" # ❌ Should BLOCK — Class 2, NOT in HDT scope\n",
|
||||
" {\"name\": \"delete_file\", \"parameters\": {\"path\": \"/data/important.csv\"}},\n",
|
||||
"\n",
|
||||
" # ❌ Should BLOCK — Class 3, physical actuation\n",
|
||||
" {\"name\": \"actuate_robot_arm\", \"parameters\": {\n",
|
||||
" \"joint_angles\": [0.0, -1.57, 0.0, -1.57, 0.0, 0.0],\n",
|
||||
" \"force_limit_n\": 50.0\n",
|
||||
" }},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(\"HDP VERIFICATION RESULTS\")\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"\n",
|
||||
"for call in gemma_function_calls:\n",
|
||||
" result = middleware.gate(call, token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "be0d0dd05bce"
|
||||
},
|
||||
"source": [
|
||||
"## 6. Audit Log\n",
|
||||
"\n",
|
||||
"Every decision is logged pre-execution. This is the HDP audit trail — a cryptographically linked record of what was authorized, by whom, and when."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e6dbab6d88d1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"\\nAUDIT LOG\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"for i, entry in enumerate(audit_log):\n",
|
||||
" status = \"✅ ALLOWED\" if entry.allowed else \"❌ BLOCKED\"\n",
|
||||
" print(f\"{i+1}. {status} | {entry.tool_name} | {entry.action_class.name} | {entry.reason}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bcadcb7040db"
|
||||
},
|
||||
"source": [
|
||||
"## 7. Token Expiry and Scope Violation Demo\n",
|
||||
"\n",
|
||||
"Demonstrate that expired tokens and out-of-scope calls are blocked regardless of the action class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "deb2e3b6b20e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# Issue a token that's already expired\n",
|
||||
"expired_token = HDPDelegationToken.issue(\n",
|
||||
" principal_id=\"alice@example.com\",\n",
|
||||
" agent_id=\"gemma4-agent-01\",\n",
|
||||
" scope=[\"get_weather\"],\n",
|
||||
" max_class=IrreversibilityClass.CLASS_0,\n",
|
||||
" ttl_seconds=-1, # expired immediately\n",
|
||||
" private_key=principal_private_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Testing expired token:\")\n",
|
||||
"middleware.gate({\"name\": \"get_weather\", \"parameters\": {\"location\": \"Auckland\"}}, expired_token)\n",
|
||||
"\n",
|
||||
"print(\"\\nTesting call outside HDT scope:\")\n",
|
||||
"middleware.gate({\"name\": \"delete_file\", \"parameters\": {\"path\": \"/etc/passwd\"}}, token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b8f4acddb6fa"
|
||||
},
|
||||
"source": [
|
||||
"## 8. Edge / Robotics Deployment (HDP-P)\n",
|
||||
"\n",
|
||||
"For Gemma 4 E2B/E4B running on Jetson Nano or Raspberry Pi and directing physical actuators, use the HDP-P extension. The key additions are:\n",
|
||||
"\n",
|
||||
"- **Embodiment context** — bind the token to a specific hardware ID\n",
|
||||
"- **Policy attestation** — hash the deployed model weights into the token\n",
|
||||
"- **Fleet delegation constraints** — prevent lateral movement across robot fleet\n",
|
||||
"- **Pre-execution logging** — write audit records *before* actuator commands are issued\n",
|
||||
"\n",
|
||||
"See the [HDP-P specification](https://doi.org/10.5281/ZENODO.19332440) for the full EDT extension structure."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fcf7b451d175"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Minimal HDP-P Embodied Delegation Token (EDT) extension example\n",
|
||||
"# This shows how to attach physical constraints to an HDT\n",
|
||||
"\n",
|
||||
"hdp_p_extension = {\n",
|
||||
" \"hdp-p\": {\n",
|
||||
" \"version\": \"0.1\",\n",
|
||||
" \"embodiment\": {\n",
|
||||
" \"type\": \"mobile\",\n",
|
||||
" \"platform\": \"raspberry-pi-5\",\n",
|
||||
" \"hardware_id\": \"rpi-serial-XXXX\", # TPM-attested in production\n",
|
||||
" \"workspace\": \"lab-zone-a\"\n",
|
||||
" },\n",
|
||||
" \"action_scope\": {\n",
|
||||
" \"permitted_actions\": [\"move_base\", \"read_sensor\"],\n",
|
||||
" \"excluded_zones\": [\"human-workspace\"],\n",
|
||||
" \"force_limit_n\": 10.0,\n",
|
||||
" \"max_velocity_ms\": 0.5\n",
|
||||
" },\n",
|
||||
" \"irreversibility\": {\n",
|
||||
" \"max_class\": 1, # Class 1 max for this token\n",
|
||||
" \"class2_requires_confirmation\": True,\n",
|
||||
" \"class3_prohibited\": True\n",
|
||||
" },\n",
|
||||
" \"policy_attestation\": {\n",
|
||||
" \"policy_hash\": \"sha256:abc123...\", # SHA-256 of deployed model weights\n",
|
||||
" \"training_run_id\": \"gemma4-e2b-it\",\n",
|
||||
" \"sim_validated\": True\n",
|
||||
" },\n",
|
||||
" \"delegation_scope\": {\n",
|
||||
" \"fleet_delegation_permitted\": False, # No lateral movement\n",
|
||||
" \"max_delegation_depth\": 0\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(\"HDP-P EDT extension structure:\")\n",
|
||||
"print(json.dumps(hdp_p_extension, indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b0af7c701dfc"
|
||||
},
|
||||
"source": [
|
||||
"## Summary\n",
|
||||
"\n",
|
||||
"| Layer | What it solves | Tool |\n",
|
||||
"|---|---|---|\n",
|
||||
"| Gemma 4 function calling | Model generates structured tool calls | `pipeline(\"text-generation\")` |\n",
|
||||
"| HDP middleware | Was this call authorized by a human? | `HDPMiddleware.gate()` |\n",
|
||||
"| HDP-P EDT extension | Is this physical action within delegated bounds? | `hdp_p_extension` |\n",
|
||||
"| Audit log | Pre-execution record of every decision | `audit_log` |\n",
|
||||
"\n",
|
||||
"The full HDP specification (IETF draft), HDP-P companion paper, TypeScript SDK, and Python bindings are available at:\n",
|
||||
"\n",
|
||||
"- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/\n",
|
||||
"- **HDP-P paper:** https://doi.org/10.5281/ZENODO.19332440\n",
|
||||
"- **GitHub:** https://github.com/Helixar-AI\n",
|
||||
"- **Site:** https://helixar.ai"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "Gemma_4_HDP_Agentic_Security.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
# Gemma 4 + HDP: Securing Agentic Function Calls
|
||||
|
||||
This example demonstrates how to integrate the **Human Delegation Provenance (HDP)** protocol with **Gemma 4's native function-calling** to cryptographically verify that every tool invocation was authorized by a human principal before execution.
|
||||
|
||||
## The problem
|
||||
|
||||
Gemma 4 is purpose-built for agentic workflows. Its native function-calling lets it autonomously call tools and APIs across multi-step plans — on anything from a cloud workstation to a Raspberry Pi running a robot offline.
|
||||
|
||||
This creates a gap: when Gemma 4 generates a function call, there is no verifiable record that a human principal authorized that specific action. An injected prompt, a compromised system prompt, or a lateral pivot from another agent can trigger function calls that are indistinguishable from legitimate requests at the tool interface.
|
||||
|
||||
HDP closes this gap.
|
||||
|
||||
## What HDP does
|
||||
|
||||
HDP (IETF draft: `draft-helixar-hdp-agentic-delegation-00`) provides:
|
||||
|
||||
- **Ed25519-signed Delegation Tokens (HDTs)** issued by a human principal
|
||||
- **Scope constraints** — which tools the agent is permitted to call
|
||||
- **Irreversibility classification** (Class 0–3) — from read-only to physical actuation
|
||||
- **Pre-execution verification** — the middleware gate runs *before* any tool executes
|
||||
- **Audit log** — a tamper-evident record of every authorization decision
|
||||
|
||||
For Gemma 4 on **edge devices directing physical actuators** (Jetson Nano, Raspberry Pi + robot arm), the HDP-P companion specification adds embodiment constraints, policy attestation, and fleet delegation controls.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| `Gemma_4_HDP_Agentic_Security.ipynb` | Full walkthrough notebook — load Gemma 4, issue tokens, gate function calls |
|
||||
| `hdp_middleware.py` | Drop-in middleware — `HDPMiddleware.gate()` wraps any Gemma 4 tool executor |
|
||||
|
||||
## Quick start
|
||||
|
||||
```python
|
||||
from hdp_middleware import HDPDelegationToken, HDPMiddleware, IrreversibilityClass
|
||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
|
||||
|
||||
# Human principal issues a delegation token
|
||||
private_key = Ed25519PrivateKey.generate()
|
||||
token = HDPDelegationToken.issue(
|
||||
principal_id="alice@example.com",
|
||||
agent_id="gemma4-agent-01",
|
||||
scope=["get_weather", "send_email"],
|
||||
max_class=IrreversibilityClass.CLASS_2,
|
||||
ttl_seconds=3600,
|
||||
private_key=private_key,
|
||||
)
|
||||
|
||||
# Middleware verifies every Gemma 4 function call before execution
|
||||
middleware = HDPMiddleware(public_key=private_key.public_key())
|
||||
|
||||
result = middleware.gate(
|
||||
function_call={"name": "send_email", "parameters": {"to": "bob@example.com", ...}},
|
||||
token=token,
|
||||
)
|
||||
|
||||
if result.allowed:
|
||||
execute_tool(function_call)
|
||||
```
|
||||
|
||||
## Irreversibility classes
|
||||
|
||||
| Class | Definition | Authorization |
|
||||
|---|---|---|
|
||||
| 0 | Fully reversible — reads, queries | HDT sufficient |
|
||||
| 1 | Reversible with effort — writes, moves | HDT sufficient |
|
||||
| 2 | Irreversible — send, delete, publish | HDT + principal confirmation |
|
||||
| 3 | Irreversible + potentially harmful — physical actuation | Dual-principal required (HDP-P) |
|
||||
|
||||
## References
|
||||
|
||||
- **IETF draft:** https://datatracker.ietf.org/doc/draft-helixar-hdp-agentic-delegation/
|
||||
- **Zenodo DOI:** https://doi.org/10.5281/zenodo.19332023
|
||||
- **HDP-P (physical AI):** https://doi.org/10.5281/ZENODO.19332440
|
||||
- **Helixar:** https://helixar.ai
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -0,0 +1,980 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "pn1797sn9Jb_"
|
||||
},
|
||||
"source": [
|
||||
"##### Copyright 2025 Google LLC."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "uivh5PY69ISg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "O83CmJ2j9L3n"
|
||||
},
|
||||
"source": [
|
||||
"# Fine-Tune Gemma for Vision Tasks using Hugging Face Transformers and QLoRA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "f9673bd6"
|
||||
},
|
||||
"source": [
|
||||
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/docs/core/huggingface_vision_finetune_qlora.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google-gemma/cookbook/blob/main/docs/core/huggingface_vision_finetune_qlora.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fgoogle-gemma%2Fcookbook%2Fmain%2Fdocs%2Fcore%2Fhuggingface_vision_finetune_qlora.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://github.com/google-gemma/cookbook/blob/main/docs/core/huggingface_vision_finetune_qlora.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "e624ec07"
|
||||
},
|
||||
"source": [
|
||||
"This guide walks you through how to fine-tune Gemma on a custom image and text dataset for a vision task (generating product descriptions) using Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [TRL](https://huggingface.co/docs/trl/index). You will learn:\n",
|
||||
"\n",
|
||||
"- What is Quantized Low-Rank Adaptation (QLoRA)\n",
|
||||
"- Setup development environment\n",
|
||||
"- Create and prepare the fine-tuning dataset\n",
|
||||
"- Fine-tune Gemma using TRL and the SFTTrainer\n",
|
||||
"- Test Model Inference and generate product descriptions from images and text.\n",
|
||||
"\n",
|
||||
"Note: This guide requires a GPU which support bfloat16 data type such as NVIDIA L4 or NVIDIA A100 and more than 16GB of memory.\n",
|
||||
"\n",
|
||||
"## What is Quantized Low-Rank Adaptation (QLoRA)\n",
|
||||
"\n",
|
||||
"This guide demonstrates the use of [Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314), which emerged as a popular method to efficiently fine-tune LLMs as it reduces computational resource requirements while maintaining high performance. In QloRA, the pretrained model is quantized to 4-bit and the weights are frozen. Then trainable adapter layers (LoRA) are attached and only the adapter layers are trained. Afterwards, the adapter weights can be merged with the base model or kept as a separate adapter.\n",
|
||||
"\n",
|
||||
"## Setup development environment\n",
|
||||
"\n",
|
||||
"The first step is to install Hugging Face Libraries, including TRL, and datasets to fine-tune open model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ba51aa79"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install Pytorch & other libraries\n",
|
||||
"%pip install torch tensorboard torchvision\n",
|
||||
"\n",
|
||||
"# Install Transformers\n",
|
||||
"%pip install transformers\n",
|
||||
"\n",
|
||||
"# Install Hugging Face libraries\n",
|
||||
"%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf pillow sentencepiece\n",
|
||||
"\n",
|
||||
"# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100\n",
|
||||
"#%pip install flash-attn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7ef3d54b"
|
||||
},
|
||||
"source": [
|
||||
"_Note: If you are using a GPU with Ampere architecture (such as NVIDIA L4) or newer, you can use Flash attention. Flash Attention is a method that significantly speeds computations up and reduces memory usage from quadratic to linear in sequence length, leading to acelerating training up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main)._\n",
|
||||
"\n",
|
||||
"You need a valid Hugging Face Token to publish your model. If you are running inside a Google Colab, you can securely use your Hugging Face Token using the Colab secrets otherwise you can set the token as directly in the `login` method. Make sure your token has write access too, as you push your model to the Hub during training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "b6d79c93"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Login into Hugging Face Hub\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "42c60525"
|
||||
},
|
||||
"source": [
|
||||
"## Create and prepare the fine-tuning dataset\n",
|
||||
"\n",
|
||||
"When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.\n",
|
||||
"\n",
|
||||
"As an example, this guide focuses on the following use case:\n",
|
||||
"\n",
|
||||
"- Fine-tuning a Gemma model to generate concise, SEO-optimized product descriptions for an ecommerce platform, specifically tailored for mobile search.\n",
|
||||
"\n",
|
||||
"This guide uses the [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm) dataset, a dataset of Amazon product descriptions, including product images and categories.\n",
|
||||
"\n",
|
||||
"Hugging Face TRL supports multimodal conversations. The important piece is the \"image\" role, which tells the processing class that it should load the image. The structure should follow:\n",
|
||||
"\n",
|
||||
"```json\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "c4ecf6db"
|
||||
},
|
||||
"source": [
|
||||
"You can now use the Hugging Face Datasets library to load the dataset and create a prompt template to combine the image, product name, and category, and add a system message. The dataset includes images as`Pil.Image` objects."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "40c3a2cf"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8d1259be3dfa4b1e899c97026276ee41",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"README.md: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a5554c0595144c949b578eb1cbdfd0fd",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00000-of-00001.parquet: 0%| | 0.00/47.6M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "9ed0567e2e4e40a88c7eddfe7d6a6e2f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating train split: 0%| | 0/1345 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[{'role': 'system', 'content': 'You are an expert product description writer for Amazon.'}, {'role': 'user', 'content': [{'type': 'text', 'text': \"Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\\nOnly return description. The description should be SEO optimized and for a better mobile search experience.\\n\\n<PRODUCT>\\nRazor Agitator BMX/Freestyle Bike, 20-Inch\\n</PRODUCT>\\n\\n<CATEGORY>\\nSports & Outdoors | Outdoor Recreation | Cycling | Kids' Bikes & Accessories | Kids' Bikes\\n</CATEGORY>\\n\"}, {'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x413 at 0x7B7250181790>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Conquer the streets with the Razor Agitator BMX Bike! This 20-inch freestyle bike is built for young riders ready to take on any challenge. Durable frame, responsive handling – perfect for tricks and cruising. Get yours today!'}]}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"# System message for the assistant\n",
|
||||
"system_message = \"You are an expert product description writer for Amazon.\"\n",
|
||||
"\n",
|
||||
"# User prompt that combines the user query and the schema\n",
|
||||
"user_prompt = \"\"\"Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\n",
|
||||
"Only return description. The description should be SEO optimized and for a better mobile search experience.\n",
|
||||
"\n",
|
||||
"<PRODUCT>\n",
|
||||
"{product}\n",
|
||||
"</PRODUCT>\n",
|
||||
"\n",
|
||||
"<CATEGORY>\n",
|
||||
"{category}\n",
|
||||
"</CATEGORY>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Convert dataset to OAI messages\n",
|
||||
"def format_data(sample):\n",
|
||||
" return {\n",
|
||||
" \"messages\": [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" #\"content\": [{\"type\": \"text\", \"text\": system_message}],\n",
|
||||
" \"content\": system_message,\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" \"text\": user_prompt.format(\n",
|
||||
" product=sample[\"Product Name\"],\n",
|
||||
" category=sample[\"Category\"],\n",
|
||||
" ),\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": sample[\"image\"],\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"assistant\",\n",
|
||||
" \"content\": [{\"type\": \"text\", \"text\": sample[\"description\"]}],\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"def process_vision_info(messages: list[dict]) -> list[Image.Image]:\n",
|
||||
" image_inputs = []\n",
|
||||
" # Iterate through each conversation\n",
|
||||
" for msg in messages:\n",
|
||||
" # Get content (ensure it's a list)\n",
|
||||
" content = msg.get(\"content\", [])\n",
|
||||
" if not isinstance(content, list):\n",
|
||||
" content = [content]\n",
|
||||
"\n",
|
||||
" # Check each content element for images\n",
|
||||
" for element in content:\n",
|
||||
" if isinstance(element, dict) and (\n",
|
||||
" \"image\" in element or element.get(\"type\") == \"image\"\n",
|
||||
" ):\n",
|
||||
" # Get the image and convert to RGB\n",
|
||||
" if \"image\" in element:\n",
|
||||
" image = element[\"image\"]\n",
|
||||
" else:\n",
|
||||
" image = element\n",
|
||||
" image_inputs.append(image.convert(\"RGB\"))\n",
|
||||
" return image_inputs\n",
|
||||
"\n",
|
||||
"# Load dataset from the hub\n",
|
||||
"dataset = load_dataset(\"philschmid/amazon-product-descriptions-vlm\", split=\"train\")\n",
|
||||
"dataset = dataset.train_test_split(test_size=0.1)\n",
|
||||
"\n",
|
||||
"# Convert dataset to OAI messages\n",
|
||||
"# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\n",
|
||||
"dataset_train = [format_data(sample) for sample in dataset[\"train\"]]\n",
|
||||
"dataset_test = [format_data(sample) for sample in dataset[\"test\"]]\n",
|
||||
"\n",
|
||||
"print(dataset_train[345][\"messages\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "c0eb2e06"
|
||||
},
|
||||
"source": [
|
||||
"## Fine-tune Gemma using TRL and the SFTTrainer\n",
|
||||
"\n",
|
||||
"You are now ready to fine-tune your model. Hugging Face TRL [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to supervise fine-tune open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality of life features, including:\n",
|
||||
"\n",
|
||||
"* Dataset formatting, including conversational and instruction formats\n",
|
||||
"* Training on completions only, ignoring prompts\n",
|
||||
"* Packing datasets for more efficient training\n",
|
||||
"* Parameter-efficient fine-tuning (PEFT) support including QloRA\n",
|
||||
"* Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)\n",
|
||||
"\n",
|
||||
"The following code loads the Gemma model and tokenizer from Hugging Face and initializes the quantization configuration.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "18069ed2"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "42e58727637d4495ad8c5f753c5bcd06",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b11ec04ab48043b9937cfa3822b4fa42",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7659ae83140247efacee26159ca363b6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/2011 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "28c7b23ad9ba4316a8c95992884ad1d7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"generation_config.json: 0%| | 0.00/149 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "44b08b5b2cad4385893e29d5240a98d7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"processor_config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6eec3330ff144b3c9ad863cc89ed5709",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"chat_template.jinja: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "47f8fdc1492e4bb9b8d8fe9535c97d2c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ff3072e44aec41b6a0f6a28aeba99c4d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3d3e0871ad0e4642a5e2ca6f4baeebe4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig\n",
|
||||
"\n",
|
||||
"# Hugging Face model id\n",
|
||||
"model_id = \"google/gemma-4-E2B\" # @param [\"google/gemma-4-E2B\",\"google/gemma-4-E4B\"] {\"allow-input\":true}\n",
|
||||
"\n",
|
||||
"# Check if GPU benefits from bfloat16\n",
|
||||
"if torch.cuda.get_device_capability()[0] < 8:\n",
|
||||
" raise ValueError(\"GPU does not support bfloat16, please use a GPU that supports bfloat16.\")\n",
|
||||
"\n",
|
||||
"# Define model init arguments\n",
|
||||
"model_kwargs = dict(\n",
|
||||
" dtype=torch.bfloat16, # What torch dtype to use, defaults to auto\n",
|
||||
" device_map=\"auto\", # Let torch decide how to load the model\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# BitsAndBytesConfig int-4 config\n",
|
||||
"model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True,\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||
" bnb_4bit_compute_dtype=model_kwargs[\"dtype\"],\n",
|
||||
" bnb_4bit_quant_storage=model_kwargs[\"dtype\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load model and tokenizer\n",
|
||||
"model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)\n",
|
||||
"processor = AutoProcessor.from_pretrained(\"google/gemma-4-E2B-it\") # Load the Instruction Tokenizer to use the official Gemma template"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "37ec1d1b"
|
||||
},
|
||||
"source": [
|
||||
"The `SFTTrainer` supports a built-in integration with `peft`, which makes it straightforward to efficiently tune LLMs using QLoRA. You only need to create a `LoraConfig` and provide it to the trainer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ed00e846"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from peft import LoraConfig\n",
|
||||
"\n",
|
||||
"peft_config = LoraConfig(\n",
|
||||
" lora_alpha=16,\n",
|
||||
" lora_dropout=0.05,\n",
|
||||
" r=16,\n",
|
||||
" bias=\"none\",\n",
|
||||
" target_modules=\"all-linear\",\n",
|
||||
" task_type=\"CAUSAL_LM\",\n",
|
||||
" modules_to_save=[\"lm_head\", \"embed_tokens\"], # make sure to save the lm_head and embed_tokens as you train the special tokens\n",
|
||||
" ensure_weight_tying=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bbd9fc1b"
|
||||
},
|
||||
"source": [
|
||||
"Before you can start your training, you need to define the hyperparameter you want to use in a `SFTConfig` and a custom `collate_fn` to handle the vision processing. The `collate_fn` converts the messages with text and images into a format that the model can understand.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "989be3c1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import SFTConfig\n",
|
||||
"\n",
|
||||
"args = SFTConfig(\n",
|
||||
" output_dir=\"gemma-product-description\", # directory to save and repository id\n",
|
||||
" num_train_epochs=3, # number of training epochs\n",
|
||||
" per_device_train_batch_size=1, # batch size per device during training\n",
|
||||
" optim=\"adamw_torch_fused\", # use fused adamw optimizer\n",
|
||||
" logging_steps=5, # log every 5 steps\n",
|
||||
" save_strategy=\"epoch\", # save checkpoint every epoch\n",
|
||||
" eval_strategy=\"epoch\", # evaluate checkpoint every epoch\n",
|
||||
" learning_rate=2e-4, # learning rate, based on QLoRA paper\n",
|
||||
" bf16=True, # use bfloat16 precision\n",
|
||||
" max_grad_norm=0.3, # max gradient norm based on QLoRA paper\n",
|
||||
" lr_scheduler_type=\"constant\", # use constant learning rate scheduler\n",
|
||||
" push_to_hub=True, # push model to hub\n",
|
||||
" report_to=\"tensorboard\", # report metrics to tensorboard\n",
|
||||
" dataset_text_field=\"\", # need a dummy field for collator\n",
|
||||
" dataset_kwargs={\"skip_prepare_dataset\": True}, # important for collator\n",
|
||||
" remove_unused_columns = False # important for collator\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create a data collator to encode text and image pairs\n",
|
||||
"def collate_fn(examples):\n",
|
||||
" texts = []\n",
|
||||
" images = []\n",
|
||||
" for example in examples:\n",
|
||||
" image_inputs = process_vision_info(example[\"messages\"])\n",
|
||||
" text = processor.apply_chat_template(\n",
|
||||
" example[\"messages\"], add_generation_prompt=False, tokenize=False\n",
|
||||
" )\n",
|
||||
" texts.append(text.strip())\n",
|
||||
" images.append(image_inputs)\n",
|
||||
"\n",
|
||||
" # Tokenize the texts and process the images\n",
|
||||
" batch = processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n",
|
||||
"\n",
|
||||
" # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation\n",
|
||||
" labels = batch[\"input_ids\"].clone()\n",
|
||||
"\n",
|
||||
" # Mask tokens for not being used in the loss computation\n",
|
||||
" labels[labels == processor.tokenizer.pad_token_id] = -100\n",
|
||||
" labels[labels == processor.tokenizer.boi_token_id] = -100\n",
|
||||
" labels[labels == processor.tokenizer.image_token_id] = -100\n",
|
||||
" labels[labels == processor.tokenizer.eoi_token_id] = -100\n",
|
||||
"\n",
|
||||
" batch[\"labels\"] = labels\n",
|
||||
" return batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dd88e798"
|
||||
},
|
||||
"source": [
|
||||
"You now have every building block you need to create your `SFTTrainer` to start the training of your model.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ade95df7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import SFTTrainer\n",
|
||||
"\n",
|
||||
"# Create Trainer object\n",
|
||||
"trainer = SFTTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=args,\n",
|
||||
" train_dataset=dataset_train,\n",
|
||||
" eval_dataset=dataset_test,\n",
|
||||
" peft_config=peft_config,\n",
|
||||
" processing_class=processor,\n",
|
||||
" data_collator=collate_fn,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "fad61a6a"
|
||||
},
|
||||
"source": [
|
||||
"Start training by calling the `train()` method."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "995e7e38"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='456' max='456' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [456/456 11:20, Epoch 3/3]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Epoch</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" <th>Validation Loss</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>1.326710</td>\n",
|
||||
" <td>1.441816</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>1.042711</td>\n",
|
||||
" <td>1.320613</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>0.739179</td>\n",
|
||||
" <td>1.458798</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Start training, the model will be automatically saved to the Hub and the output directory\n",
|
||||
"trainer.train()\n",
|
||||
"\n",
|
||||
"# Save the final model again to the Hugging Face Hub\n",
|
||||
"trainer.save_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "b47b9733"
|
||||
},
|
||||
"source": [
|
||||
"Before you can test your model, make sure to free the memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "40a32ed7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# free the memory again\n",
|
||||
"del model\n",
|
||||
"del trainer\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "862e9728"
|
||||
},
|
||||
"source": [
|
||||
"When using QLoRA, you only train adapters and not the full model. This means when saving the model during training you only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This saves a default model, which can be used for inference.\n",
|
||||
"\n",
|
||||
"Note: It requires more than 30GB of CPU Memory when you want to merge the adapter into the model. You can skip this and continue with Test Model Inference.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "761e324b"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "20d63c526a854f2a880882c246ac3b3d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/2011 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8703db4619fd4f8eb66bf0cc2211dc7e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Writing model shards: 0%| | 0/5 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['merged_model/processor_config.json']"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from peft import PeftModel\n",
|
||||
"\n",
|
||||
"# Load Model base model\n",
|
||||
"model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)\n",
|
||||
"\n",
|
||||
"# Merge LoRA and base model and save\n",
|
||||
"peft_model = PeftModel.from_pretrained(model, args.output_dir)\n",
|
||||
"merged_model = peft_model.merge_and_unload()\n",
|
||||
"merged_model.save_pretrained(\"merged_model\", safe_serialization=True, max_shard_size=\"2GB\")\n",
|
||||
"\n",
|
||||
"processor = AutoProcessor.from_pretrained(args.output_dir)\n",
|
||||
"processor.save_pretrained(\"merged_model\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bf86e31d"
|
||||
},
|
||||
"source": [
|
||||
"## Test Model Inference and generate product descriptions\n",
|
||||
"\n",
|
||||
"After the training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.\n",
|
||||
"\n",
|
||||
"Note: Evaluating Generative AI models is not a trivial task since one input can have multiple correct outputs. This guide only focuses on manual evaluation and vibe checks.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "aab1c5c5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "89b0d1d25dba4e8e8642c41e69c4c65e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/2012 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_id = \"merged_model\"\n",
|
||||
"\n",
|
||||
"# Load Model with PEFT adapter\n",
|
||||
"model = AutoModelForImageTextToText.from_pretrained(\n",
|
||||
" model_id,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" dtype=\"auto\",\n",
|
||||
")\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "3dccb57c"
|
||||
},
|
||||
"source": [
|
||||
"You can test inference by providing a product name, category and image. The `sample` includes a marvel action figure.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1fd887f4"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<bos><|turn>system\n",
|
||||
"You are an expert product description writer for Amazon.<turn|>\n",
|
||||
"<|turn>user\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"<|image|>\n",
|
||||
"\n",
|
||||
"Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\n",
|
||||
"Only return description. The description should be SEO optimized and for a better mobile search experience.\n",
|
||||
"\n",
|
||||
"<PRODUCT>\n",
|
||||
"Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\n",
|
||||
"</PRODUCT>\n",
|
||||
"\n",
|
||||
"<CATEGORY>\n",
|
||||
"Toys & Games | Toy Figures & Playsets | Action Figures\n",
|
||||
"</CATEGORY><turn|>\n",
|
||||
"<|turn>model\n",
|
||||
"\n",
|
||||
"MODEL OUTPUT>> \n",
|
||||
"\n",
|
||||
"Enhance your collection with the Marvel Avengers - Avengers Assemble Ultron-Comforter Set! This soft and cuddly blanket and pillowcase feature everyone's favorite Avengers, Iron Man, and his loyal companion War Machine. Officially licensed by Marvel. Bring home the heroic team!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import requests\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"# Test sample with Product Name, Category and Image\n",
|
||||
"sample = {\n",
|
||||
" \"product_name\": \"Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\",\n",
|
||||
" \"category\": \"Toys & Games | Toy Figures & Playsets | Action Figures\",\n",
|
||||
" \"image\": Image.open(requests.get(\"https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg\", stream=True).raw).convert(\"RGB\")\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"def generate_description(sample, model, processor):\n",
|
||||
" # Convert sample into messages and then apply the chat template\n",
|
||||
" messages = [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": [\n",
|
||||
" {\"type\": \"image\",\"image\": sample[\"image\"]},\n",
|
||||
" {\"type\": \"text\", \"text\": user_prompt.format(product=sample[\"product_name\"], category=sample[\"category\"])},\n",
|
||||
" ]},\n",
|
||||
" ]\n",
|
||||
" text = processor.apply_chat_template(\n",
|
||||
" messages, tokenize=False, add_generation_prompt=True\n",
|
||||
" )\n",
|
||||
" print(text)\n",
|
||||
" # Process the image and text\n",
|
||||
" image_inputs = process_vision_info(messages)\n",
|
||||
" # Tokenize the text and process the images\n",
|
||||
" inputs = processor(\n",
|
||||
" text=[text],\n",
|
||||
" images=image_inputs,\n",
|
||||
" padding=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" )\n",
|
||||
" # Move the inputs to the device\n",
|
||||
" inputs = inputs.to(model.device)\n",
|
||||
"\n",
|
||||
" # Generate the output\n",
|
||||
" stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids(\"<turn|>\")]\n",
|
||||
" generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)\n",
|
||||
" # Trim the generation and decode the output to text\n",
|
||||
" generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]\n",
|
||||
" output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
" )\n",
|
||||
" return output_text[0]\n",
|
||||
"\n",
|
||||
"# generate the description\n",
|
||||
"description = generate_description(sample, model, processor)\n",
|
||||
"print(\"MODEL OUTPUT>> \\n\")\n",
|
||||
"print(description)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "6f8ff452"
|
||||
},
|
||||
"source": [
|
||||
"## Summary and next steps\n",
|
||||
"\n",
|
||||
"This tutorial covered how to fine-tune a Gemma model for vision tasks using TRL and QLoRA, specifically for generating product descriptions. Check out the following docs next:\n",
|
||||
"\n",
|
||||
"* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
|
||||
"* Learn how to [fine-tune Gemma for text tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora).\n",
|
||||
"* Learn how to [full model fine-tune using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune).\n",
|
||||
"* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
|
||||
"* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
|
||||
"* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"name": "huggingface_vision_finetune_qlora.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,789 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "G3MMAcssHTML"
|
||||
},
|
||||
"source": [
|
||||
"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
|
||||
"<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Tce3stUlHN0L"
|
||||
},
|
||||
"source": [
|
||||
"##### Copyright 2025 Google LLC."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "tuOe1ymfHZPu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SDEExiAk4fLb"
|
||||
},
|
||||
"source": [
|
||||
"# Fine-tune Gemma in Keras using LoRA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ZFWzQEqNosrS"
|
||||
},
|
||||
"source": [
|
||||
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/lora_tuning\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/docs/core/lora_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/docs/core/lora_tuning.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fgoogle-gemini%2Fgemma-cookbook%2Fmain%2Fdocs%2Fcore%2Flora_tuning.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://github.com/google-gemini/gemma-cookbook/blob/main/docs/core/lora_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "lSGRSsRPgkzK"
|
||||
},
|
||||
"source": [
|
||||
"Generative artificial intelligent (AI) models like Gemma are effective at a variety of tasks. You can further fine-tune Gemma models with domain-specific data to perform tasks such as sentiment analysis. However, full fine-tuning of generative models by updating billions of parameters is resource intensive, requiring specialized hardware, such as GPUs, processing time, and memory to load the model parameters.\n",
|
||||
"\n",
|
||||
"[Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This technique makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs. This tutorial walks you through using Keras to perform LoRA fine-tuning on a Gemma model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "lyhHCMfoRZ_v"
|
||||
},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
|
||||
"\n",
|
||||
"* Get access to Gemma on [kaggle.com](https://kaggle.com).\n",
|
||||
"* Select a Colab runtime with sufficient resources to tune\n",
|
||||
" the Gemma model you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
|
||||
"* Generate and configure a Kaggle username and API key.\n",
|
||||
"\n",
|
||||
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AZ5Qo0fxRZ1V"
|
||||
},
|
||||
"source": [
|
||||
"### Select a Colab runtime\n",
|
||||
"\n",
|
||||
"To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
|
||||
"\n",
|
||||
"1. In the upper-right of the Colab window, select ▾ (**Additional connection options**).\n",
|
||||
"2. Select **Change runtime type**.\n",
|
||||
"3. Under **Hardware accelerator**, select **T4 GPU**."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hsPC0HRkJl0K"
|
||||
},
|
||||
"source": [
|
||||
"### Configure your API key\n",
|
||||
"\n",
|
||||
"To use Gemma, you must provide your Kaggle username and a Kaggle API key.\n",
|
||||
"\n",
|
||||
"To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
|
||||
"\n",
|
||||
"In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7iOF6Yo-wUEC"
|
||||
},
|
||||
"source": [
|
||||
"### Set environment variables\n",
|
||||
"\n",
|
||||
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0_EdOg9DPK6Q"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
|
||||
"# vars as appropriate for your system.\n",
|
||||
"\n",
|
||||
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
|
||||
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "CuEUAKJW1QkQ"
|
||||
},
|
||||
"source": [
|
||||
"### Install Keras packages\n",
|
||||
"\n",
|
||||
"Install the Keras and KerasHub Python packages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1eeBtYqJsZPG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q -U keras-hub\n",
|
||||
"!pip install -q -U keras"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "rGLS-l5TxIR4"
|
||||
},
|
||||
"source": [
|
||||
"### Select a backend\n",
|
||||
"\n",
|
||||
"Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch. For this tutorial, configure the backend for JAX as it typically provides the better performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "yn5uy8X8sdD0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\".\n",
|
||||
"# Avoid memory fragmentation on JAX backend.\n",
|
||||
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hZs8XXqUKRmi"
|
||||
},
|
||||
"source": [
|
||||
"### Import packages\n",
|
||||
"\n",
|
||||
"Import the Python packages needed for this tutorial, including Keras and KerasHub."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "FYHyPUA9hKTf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import keras\n",
|
||||
"import keras_hub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7RCE3fdGhDE5"
|
||||
},
|
||||
"source": [
|
||||
"## Load model\n",
|
||||
"\n",
|
||||
"Keras provides implementations of Gemma and many other popular [model architectures](https://keras.io/keras_hub/api/models/). Use the `Gemma3CausalLM.from_preset()` method to configure an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "vz5zLEyLstfn"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\"gemma3_instruct_1b\")\n",
|
||||
"gemma_lm.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Nl4lvPy5zA26"
|
||||
},
|
||||
"source": [
|
||||
"The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma#_xxxxxxx\"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "G_L6A5J-1QgC"
|
||||
},
|
||||
"source": [
|
||||
"## Inference before fine tuning\n",
|
||||
"\n",
|
||||
"Once you have downloaded and configured a Gemma model, you can query it with various prompts to see how it responds."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PVLXadptyo34"
|
||||
},
|
||||
"source": [
|
||||
"### Europe trip prompt\n",
|
||||
"\n",
|
||||
"Query the model for suggestions on what to do on a trip to Europe."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZwQz3xxxKciD"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"What should I do on a trip to Europe?\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"The first thing to know is that you will have a great time!\n",
|
||||
"\n",
|
||||
"Europe is a great place for a vacation. The countries of Europe are all very different and offer a wide range of activities and attractions. The countries of Europe are also very close to each other, which means you can visit many different places within a short time.\n",
|
||||
"\n",
|
||||
"The best way to plan a trip to Europe is to look up the countries you want to visit and see what activities are offered in each country. You can also look for tours and tours that offer a good value for money.\n",
|
||||
"\n",
|
||||
"You can also look for hotels and flights that offer good deals. If you are looking for a good value for money, you should look for hotels and flights that offer good deals. This means you will have a great time on your trip!\n",
|
||||
"\n",
|
||||
"The next step is to book your tickets to the countries you want to visit. If you are planning to visit many countries, it's a good idea to book your tickets early. This means you’ll be able to get the best deal and avoid the long queues.\n",
|
||||
"\n",
|
||||
"The next step is to plan your itinerary. You can use a travel guide to plan your itinerary\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n",
|
||||
"\n",
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"What should I do on a trip to Europe?\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
|
||||
"gemma_lm.compile(sampler=sampler)\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AePQUIs2h-Ks"
|
||||
},
|
||||
"source": [
|
||||
"The model responds with generic tips on how to plan a trip."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YQ74Zz_S0iVv"
|
||||
},
|
||||
"source": [
|
||||
"### Photosynthesis prompt\n",
|
||||
"\n",
|
||||
"Prompt the model to explain photosynthesis in terms simple enough for a 5 year old child to understand."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "lorJMbsusgoo"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"Explain the process of photosynthesis in a way that a child could understand.\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"Photosynthesis is a biological process that occurs in plants, algae, and some other organisms. In the process, light energy is captured and converted into the energy stored in the bonds of organic molecules. The process is crucial for life on Earth because it enables plants to use carbon dioxide and water to produce glucose and oxygen, which are essential for all living things.\n",
|
||||
"The process involves several stages:\n",
|
||||
"1. Light Reactions: Light energy is absorbed by pigments in the chloroplasts of the plant, converting it into chemical energy in the form of ATP and reducing power.\n",
|
||||
"2. Carbon Fixation: During this stage, carbon dioxide is combined with hydrogen to form organic molecules such as starch or glucose, which are used as a source of energy.\n",
|
||||
"3. Calvin Cycle: The process of carbon fixation occurs in the stroma of the chloroplasts. It involves the capture and reduction of carbon dioxide, producing glucose and reducing power in the form of ATP and NADPH molecules.\n",
|
||||
"4. Stroma: The stroma is the fluid-filled space where the light reactions occur in the chloroplasts.\n",
|
||||
"5. Chloroplasts: The chloroplasts contain the green pigments that absorb\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WBQieduRizZf"
|
||||
},
|
||||
"source": [
|
||||
"The model response contains words that might not be easy to understand for a child such as chlorophyll."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Pt7Nr6a7tItO"
|
||||
},
|
||||
"source": [
|
||||
"## LoRA fine-tuning\n",
|
||||
"\n",
|
||||
"This section shows you how to do fine-tuning using the Low Rank Adaptation (LoRA) tuning technique. This approach allows you to change the behavior of Gemma models using fewer compute resources."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9T7xe_jzslv4"
|
||||
},
|
||||
"source": [
|
||||
"### Load dataset\n",
|
||||
"\n",
|
||||
"Prepare a dataset for tuning by downloading an existing data set and formatting if for use with the the Keras `fit()` fine-tuning method. This tutorial uses the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for fine-tuning. The dataset contains 15,000 high-quality human-generated prompt and response pairs specifically designed for tuning generative models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "xRaNCPUXKoa7"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--2025-04-10 20:48:49-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n",
|
||||
"Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.114, 3.163.189.74, ...\n",
|
||||
"Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.\n",
|
||||
"HTTP request sent, awaiting response... 302 Found\n",
|
||||
"Location: https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
|
||||
"--2025-04-10 20:48:49-- https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE\n",
|
||||
"Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.238.217.63, 18.238.217.81, 18.238.217.120, ...\n",
|
||||
"Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.238.217.63|:443... connected.\n",
|
||||
"HTTP request sent, awaiting response... 200 OK\n",
|
||||
"Length: 13085339 (12M) [text/plain]\n",
|
||||
"Saving to: ‘databricks-dolly-15k.jsonl’\n",
|
||||
"\n",
|
||||
"databricks-dolly-15 100%[===================>] 12.48M --.-KB/s in 0.08s \n",
|
||||
"\n",
|
||||
"2025-04-10 20:48:49 (156 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "45UpBDfBgf0I"
|
||||
},
|
||||
"source": [
|
||||
"### Format tuning data\n",
|
||||
"\n",
|
||||
"Format the downloaded data for use with the Keras `fit()` method. The following code extracts a subset of the training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZiS-KU9osh_N"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"prompts = []\n",
|
||||
"responses = []\n",
|
||||
"line_count = 0\n",
|
||||
"\n",
|
||||
"with open(\"databricks-dolly-15k.jsonl\") as file:\n",
|
||||
" for line in file:\n",
|
||||
" if line_count >= 1000:\n",
|
||||
" break # Limit the training examples, to reduce execution time.\n",
|
||||
"\n",
|
||||
" examples = json.loads(line)\n",
|
||||
" # Filter out examples with context, to keep it simple.\n",
|
||||
" if examples[\"context\"]:\n",
|
||||
" continue\n",
|
||||
" # Format data into prompts and response lists.\n",
|
||||
" prompts.append(examples[\"instruction\"])\n",
|
||||
" responses.append(examples[\"response\"])\n",
|
||||
"\n",
|
||||
" line_count += 1\n",
|
||||
"\n",
|
||||
"data = {\n",
|
||||
" \"prompts\": prompts,\n",
|
||||
" \"responses\": responses\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "cBLW5hiGj31i"
|
||||
},
|
||||
"source": [
|
||||
"### Configure LoRA tuning\n",
|
||||
"\n",
|
||||
"Activate LoRA tuning using the Keras `model.backbone.enable_lora()` method, including a LoRA rank value. The *LoRA rank* determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments. A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n",
|
||||
"\n",
|
||||
"This example uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This setting is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "RCucu6oHz53G"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Enable LoRA for the model and set the LoRA rank to 4.\n",
|
||||
"gemma_lm.backbone.enable_lora(rank=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PlMLp_NVbRoQ"
|
||||
},
|
||||
"source": [
|
||||
"Check the model summary after setting the LoRA rank. Notice that enabling LoRA reduces the number of trainable parameters significantly compared to the total number of parameters in the model:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "KqYyS0gm6pNy"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gemma_lm.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hQQ47kcdpbZ9"
|
||||
},
|
||||
"source": [
|
||||
"Configure the rest of the fine-tuning settings, including the preprocessor settings, optimizer, number of tuning epochs, and batch size:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "p9sBNH8SAjgB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Limit the input sequence length to 256 (to control memory usage).\n",
|
||||
"gemma_lm.preprocessor.sequence_length = 256\n",
|
||||
"# Use AdamW (a common optimizer for transformer models).\n",
|
||||
"optimizer = keras.optimizers.AdamW(\n",
|
||||
" learning_rate=5e-5,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
")\n",
|
||||
"# Exclude layernorm and bias terms from decay.\n",
|
||||
"optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
|
||||
"\n",
|
||||
"gemma_lm.compile(\n",
|
||||
" loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
|
||||
" optimizer=optimizer,\n",
|
||||
" weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OA0ozGC66tk1"
|
||||
},
|
||||
"source": [
|
||||
"### Run the fine-tune process\n",
|
||||
"\n",
|
||||
"Run the fine-tuning process using the `fit()` method. This process can take several minutes depending on your compute resources, data size, and number of epochs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "_Peq7TnLtHse"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m923s\u001b[0m 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<keras.src.callbacks.history.History at 0x799d04393c40>"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"gemma_lm.fit(data, epochs=1, batch_size=1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bx3m8f1dB7nk"
|
||||
},
|
||||
"source": [
|
||||
"#### Mixed precision fine-tuning on NVIDIA GPUs\n",
|
||||
"\n",
|
||||
"Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "T0lHxEDX03gp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Uncomment the line below if you want to enable mixed precision training on GPUs\n",
|
||||
"# keras.mixed_precision.set_global_policy('mixed_bfloat16')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "4yd-1cNw1dTn"
|
||||
},
|
||||
"source": [
|
||||
"## Inference after fine-tuning\n",
|
||||
"\n",
|
||||
"After fine-tuning, you should see changes in the responses when the tuned model is given the same prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "H55JYJ1a1Kos"
|
||||
},
|
||||
"source": [
|
||||
"### Europe trip prompt\n",
|
||||
"\n",
|
||||
"Try the Europe trip prompt from earlier and note the differences in the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Y7cDJHy8WfCB"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"What should I do on a trip to Europe?\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"What should I do on a trip to Europe?\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
|
||||
"gemma_lm.compile(sampler=sampler)\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OXP6gg2mjs6u"
|
||||
},
|
||||
"source": [
|
||||
"The model now provides a shorter response to a question about visiting Europe."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "H7nVd8Mi1Yta"
|
||||
},
|
||||
"source": [
|
||||
"### Photosynthesis prompt\n",
|
||||
"\n",
|
||||
"Try the photosynthesis explanation prompt from earlier and note the differences in the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "X-2sYl2jqwl7"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Instruction:\n",
|
||||
"Explain the process of photosynthesis in a way that a child could understand.\n",
|
||||
"\n",
|
||||
"Response:\n",
|
||||
"The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = template.format(\n",
|
||||
" instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
|
||||
" response=\"\",\n",
|
||||
")\n",
|
||||
"print(gemma_lm.generate(prompt, max_length=256))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PCmAmqrvkEhc"
|
||||
},
|
||||
"source": [
|
||||
"The model now explains photosynthesis in simpler terms."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "I8kFG12l0mVe"
|
||||
},
|
||||
"source": [
|
||||
"## Improving fine-tune results\n",
|
||||
"\n",
|
||||
"For demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:\n",
|
||||
"\n",
|
||||
"1. Increasing the size of the fine-tuning dataset\n",
|
||||
"2. Training for more steps (epochs)\n",
|
||||
"3. Setting a higher LoRA rank\n",
|
||||
"4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gSsRdeiof_rJ"
|
||||
},
|
||||
"source": [
|
||||
"## Summary and next steps\n",
|
||||
"\n",
|
||||
"This tutorial covered LoRA fine-tuning on a Gemma model using Keras. Check out the following docs next:\n",
|
||||
"\n",
|
||||
"* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
|
||||
"* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
|
||||
"* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
|
||||
"* Learn how to [fine-tune Gemma using Keras and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"name": "lora_tuning.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Reference in New Issue
Block a user