eecebe7ef5
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1120 lines
59 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "UUYMxQuf8zGu"
|
|
},
|
|
"source": [
|
|
"##### Copyright 2025 Google LLC."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "3x6t11lI829b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"# you may not use this file except in compliance with the License.\n",
|
|
"# You may obtain a copy of the License at\n",
|
|
"#\n",
|
|
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"#\n",
|
|
"# Unless required by applicable law or agreed to in writing, software\n",
|
|
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"# See the License for the specific language governing permissions and\n",
|
|
"# limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WJwmK4C087wa"
|
|
},
|
|
"source": [
|
|
"# Fine-Tune Gemma using Hugging Face Transformers and QLoRA"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "f9673bd6"
|
|
},
|
|
"source": [
|
|
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
|
" <td>\n",
|
|
" <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
|
|
" </td>\n",
|
|
" <td>\n",
|
|
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/docs/core/huggingface_text_finetune_qlora.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
|
" </td>\n",
|
|
" <td>\n",
|
|
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google-gemma/cookbook/blob/main/docs/core/huggingface_text_finetune_qlora.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
|
|
" </td>\n",
|
|
" <td>\n",
|
|
" <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fgoogle-gemma%2Fcookbook%2Fmain%2Fdocs%2Fcore%2Fhuggingface_text_finetune_qlora.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
|
|
" </td>\n",
|
|
" <td>\n",
|
|
" <a target=\"_blank\" href=\"https://github.com/google-gemma/cookbook/blob/main/docs/core/huggingface_text_finetune_qlora.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
|
" </td>\n",
|
|
"</table>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "e624ec07"
|
|
},
|
|
"source": [
|
|
"This guide walks you through fine-tuning Gemma on a custom text-to-SQL dataset using Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [TRL](https://huggingface.co/docs/trl/index). You will learn:\n",
|
|
"\n",
|
|
"- What is Quantized Low-Rank Adaptation (QLoRA)\n",
|
|
"- Set up the development environment\n",
|
|
"- Create and prepare the fine-tuning dataset\n",
|
|
"- Fine-tune Gemma using TRL and the SFTTrainer\n",
|
|
"- Test model inference and generate SQL queries\n",
|
|
"\n",
|
|
"Note: This guide was created to run in Google Colaboratory on an NVIDIA T4 GPU with 16 GB of memory using Gemma 1B, but it can be adapted to run on larger GPUs and larger models.\n",
|
|
"\n",
|
|
"## What is Quantized Low-Rank Adaptation (QLoRA)\n",
|
|
"\n",
|
|
"This guide demonstrates the use of [Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314), which has become a popular method for efficiently fine-tuning LLMs because it reduces computational resource requirements while maintaining high performance. In QLoRA, the pretrained model is quantized to 4-bit and its weights are frozen. Trainable adapter layers (LoRA) are then attached, and only the adapter layers are trained. Afterwards, the adapter weights can be merged with the base model or kept as a separate adapter.\n",
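"\n",
"The arithmetic behind QLoRA can be sketched in plain Python. This is a rough estimate under stated assumptions, not a measurement. It assumes 2 bytes per parameter for 16-bit weights and 0.5 bytes per parameter for 4-bit (NF4) weights. It ignores quantization constants, activations, and optimizer state, and it uses hypothetical layer dimensions. A LoRA adapter on a `d_in x d_out` linear layer adds two low-rank matrices, `A` (`r x d_in`) and `B` (`d_out x r`). It therefore trains `r * (d_in + d_out)` parameters instead of `d_in * d_out`.\n",
"\n",
"```python\n",
"# Rough QLoRA memory and parameter arithmetic (illustrative assumptions only).\n",
"BYTES_PER_PARAM = {'bf16': 2.0, 'nf4': 0.5}  # 16-bit vs 4-bit storage; ignores quantization constants\n",
"\n",
"def weight_memory_gb(num_params, dtype):\n",
"    # Approximate memory for the frozen base-model weights, in GB (1e9 bytes).\n",
"    return num_params * BYTES_PER_PARAM[dtype] / 1e9\n",
"\n",
"def lora_params(d_in, d_out, r):\n",
"    # Trainable parameters LoRA adds to one linear layer: A (r x d_in) plus B (d_out x r).\n",
"    return r * (d_in + d_out)\n",
"\n",
"print(weight_memory_gb(1e9, 'bf16'))  # 2.0\n",
"print(weight_memory_gb(1e9, 'nf4'))   # 0.5\n",
"# Hypothetical 2048 x 2048 projection with rank r=16.\n",
"print(lora_params(2048, 2048, 16), 2048 * 2048)  # 65536 4194304\n",
"```\n",
"\n",
"Under these assumptions, 4-bit storage cuts weight memory by 4x. The adapter trains well under 2% of the parameters in each adapted layer.\n",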
|
|
"\n",
|
|
"## Set up the development environment\n",
|
|
"\n",
|
|
"The first step is to install the Hugging Face libraries, including TRL and Datasets, which are used to fine-tune open models. TRL also supports a range of RLHF and alignment techniques."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ba51aa79"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.10.0+cu128)\n",
|
|
"Requirement already satisfied: tensorboard in /usr/local/lib/python3.12/dist-packages (2.19.0)\n",
|
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch) (3.25.2)\n",
|
|
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.15.0)\n",
|
|
"Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch) (75.2.0)\n",
|
|
"Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.14.0)\n",
|
|
"Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch) (3.6.1)\n",
|
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n",
|
|
"Requirement already satisfied: fsspec>=0.8.5 in /usr/local/lib/python3.12/dist-packages (from torch) (2025.3.0)\n",
|
|
"Requirement already satisfied: cuda-bindings==12.9.4 in /usr/local/lib/python3.12/dist-packages (from torch) (12.9.4)\n",
|
|
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch) (12.8.93)\n",
|
|
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch) (12.8.90)\n",
|
|
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch) (12.8.90)\n",
|
|
"Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch) (9.10.2.21)\n",
|
|
"Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /usr/local/lib/python3.12/dist-packages (from torch) (12.8.4.1)\n",
|
|
"Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /usr/local/lib/python3.12/dist-packages (from torch) (11.3.3.83)\n",
|
|
"Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /usr/local/lib/python3.12/dist-packages (from torch) (10.3.9.90)\n",
|
|
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /usr/local/lib/python3.12/dist-packages (from torch) (11.7.3.90)\n",
|
|
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /usr/local/lib/python3.12/dist-packages (from torch) (12.5.8.93)\n",
|
|
"Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch) (0.7.1)\n",
|
|
"Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch) (2.27.5)\n",
|
|
"Requirement already satisfied: nvidia-nvshmem-cu12==3.4.5 in /usr/local/lib/python3.12/dist-packages (from torch) (3.4.5)\n",
|
|
"Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch) (12.8.90)\n",
|
|
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch) (12.8.93)\n",
|
|
"Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.13.1.3)\n",
|
|
"Requirement already satisfied: triton==3.6.0 in /usr/local/lib/python3.12/dist-packages (from torch) (3.6.0)\n",
|
|
"Requirement already satisfied: cuda-pathfinder~=1.1 in /usr/local/lib/python3.12/dist-packages (from cuda-bindings==12.9.4->torch) (1.4.3)\n",
|
|
"Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (1.4.0)\n",
|
|
"Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (1.78.0)\n",
|
|
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (3.10.2)\n",
|
|
"Requirement already satisfied: numpy>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (2.0.2)\n",
|
|
"Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from tensorboard) (26.0)\n",
|
|
"Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (5.29.6)\n",
|
|
"Requirement already satisfied: six>1.9 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (1.17.0)\n",
|
|
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (0.7.2)\n",
|
|
"Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from tensorboard) (3.1.6)\n",
|
|
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
|
|
"Requirement already satisfied: markupsafe>=2.1.1 in /usr/local/lib/python3.12/dist-packages (from werkzeug>=1.0.1->tensorboard) (3.0.3)\n",
|
|
"Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n",
|
|
"Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (1.13.0)\n",
|
|
"Collecting evaluate\n",
|
|
" Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)\n",
|
|
"Collecting bitsandbytes\n",
|
|
" Downloading bitsandbytes-0.49.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)\n",
|
|
"Collecting trl\n",
|
|
" Downloading trl-1.0.0-py3-none-any.whl.metadata (11 kB)\n",
|
|
"Requirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (0.18.1)\n",
|
|
"Requirement already satisfied: protobuf in /usr/local/lib/python3.12/dist-packages (5.29.6)\n",
|
|
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.12/dist-packages (0.2.1)\n",
|
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets) (3.25.2)\n",
|
|
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.0.2)\n",
|
|
"Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n",
|
|
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n",
|
|
"Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n",
|
|
"Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n",
|
|
"Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets) (4.67.3)\n",
|
|
"Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.6.0)\n",
|
|
"Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n",
|
|
"Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n",
|
|
"Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (1.7.1)\n",
|
|
"Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets) (26.0)\n",
|
|
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets) (6.0.3)\n",
|
|
"Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (5.9.5)\n",
|
|
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.10.0+cu128)\n",
|
|
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from accelerate) (0.7.0)\n",
|
|
"Collecting datasets\n",
|
|
" Downloading datasets-4.8.4-py3-none-any.whl.metadata (19 kB)\n",
|
|
"Requirement already satisfied: transformers>=4.56.2 in /usr/local/lib/python3.12/dist-packages (from trl) (5.5.0.dev0)\n",
|
|
"Collecting pyarrow>=21.0.0 (from datasets)\n",
|
|
" Downloading pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.1 kB)\n",
|
|
"Requirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.28.1)\n",
|
|
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.13.3)\n",
|
|
"Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
|
|
"Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (2026.2.25)\n",
|
|
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
|
|
"Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (3.11)\n",
|
|
"Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
|
|
"Requirement already satisfied: hf-xet<2.0.0,>=1.4.2 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.4.2)\n",
|
|
"Requirement already satisfied: typer in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (0.24.1)\n",
|
|
"Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (4.15.0)\n",
|
|
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.6)\n",
|
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n",
|
|
"Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (75.2.0)\n",
|
|
"Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.14.0)\n",
|
|
"Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.6.1)\n",
|
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.6)\n",
|
|
"Requirement already satisfied: cuda-bindings==12.9.4 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.9.4)\n",
|
|
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)\n",
|
|
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)\n",
|
|
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)\n",
|
|
"Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (9.10.2.21)\n",
|
|
"Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.4.1)\n",
|
|
"Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.3.3.83)\n",
|
|
"Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (10.3.9.90)\n",
|
|
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.7.3.90)\n",
|
|
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.5.8.93)\n",
|
|
"Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (0.7.1)\n",
|
|
"Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (2.27.5)\n",
|
|
"Requirement already satisfied: nvidia-nvshmem-cu12==3.4.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.4.5)\n",
|
|
"Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)\n",
|
|
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)\n",
|
|
"Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.13.1.3)\n",
|
|
"Requirement already satisfied: triton==3.6.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.6.0)\n",
|
|
"Requirement already satisfied: cuda-pathfinder~=1.1 in /usr/local/lib/python3.12/dist-packages (from cuda-bindings==12.9.4->torch>=2.0.0->accelerate) (1.4.3)\n",
|
|
"Requirement already satisfied: regex>=2025.10.22 in /usr/local/lib/python3.12/dist-packages (from transformers>=4.56.2->trl) (2025.11.3)\n",
|
|
"Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers>=4.56.2->trl) (0.22.2)\n",
|
|
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n",
|
|
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n",
|
|
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.3)\n",
|
|
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.1)\n",
|
|
"Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n",
|
|
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (25.4.0)\n",
|
|
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n",
|
|
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n",
|
|
"Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.4.1)\n",
|
|
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.23.0)\n",
|
|
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
|
|
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)\n",
|
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.3)\n",
|
|
"Requirement already satisfied: click>=8.2.1 in /usr/local/lib/python3.12/dist-packages (from typer->huggingface-hub>=0.24.0->datasets) (8.3.1)\n",
|
|
"Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer->huggingface-hub>=0.24.0->datasets) (1.5.4)\n",
|
|
"Requirement already satisfied: rich>=12.3.0 in /usr/local/lib/python3.12/dist-packages (from typer->huggingface-hub>=0.24.0->datasets) (13.9.4)\n",
|
|
"Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer->huggingface-hub>=0.24.0->datasets) (0.0.4)\n",
|
|
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=12.3.0->typer->huggingface-hub>=0.24.0->datasets) (4.0.0)\n",
|
|
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=12.3.0->typer->huggingface-hub>=0.24.0->datasets) (2.19.2)\n",
|
|
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer->huggingface-hub>=0.24.0->datasets) (0.1.2)\n",
|
|
"Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[?25hDownloading bitsandbytes-0.49.2-py3-none-manylinux_2_24_x86_64.whl (60.7 MB)\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.7/60.7 MB\u001b[0m \u001b[31m38.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[?25hDownloading trl-1.0.0-py3-none-any.whl (630 kB)\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m630.8/630.8 kB\u001b[0m \u001b[31m64.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[?25hDownloading datasets-4.8.4-py3-none-any.whl (526 kB)\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m527.0/527.0 kB\u001b[0m \u001b[31m46.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[?25hDownloading pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl (47.6 MB)\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.6/47.6 MB\u001b[0m \u001b[31m55.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[?25hInstalling collected packages: pyarrow, bitsandbytes, datasets, evaluate, trl\n",
|
|
" Attempting uninstall: pyarrow\n",
|
|
" Found existing installation: pyarrow 18.1.0\n",
|
|
" Uninstalling pyarrow-18.1.0:\n",
|
|
" Successfully uninstalled pyarrow-18.1.0\n",
|
|
" Attempting uninstall: datasets\n",
|
|
" Found existing installation: datasets 4.0.0\n",
|
|
" Uninstalling datasets-4.0.0:\n",
|
|
" Successfully uninstalled datasets-4.0.0\n",
|
|
"Successfully installed bitsandbytes-0.49.2 datasets-4.8.4 evaluate-0.4.6 pyarrow-23.0.1 trl-1.0.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Install PyTorch & other libraries\n",
|
|
"%pip install torch tensorboard\n",
|
|
"\n",
|
|
"# Install Transformers\n",
|
|
"%pip install transformers\n",
|
|
"\n",
|
|
"# Install Hugging Face libraries\n",
|
|
"%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf sentencepiece\n",
|
|
"\n",
|
|
"# Uncomment the next line if you are running on a GPU that supports the BF16 data type and FlashAttention, such as NVIDIA L4 or NVIDIA A100\n",
|
|
"#%pip install flash-attn"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "7ef3d54b"
|
|
},
|
|
"source": [
|
|
"_Note: If you are using a GPU with the NVIDIA Ampere architecture (such as the A100) or a newer one (such as the Ada Lovelace-based L4), you can use FlashAttention. FlashAttention significantly speeds up attention computation and reduces its memory usage from quadratic to linear in sequence length, which can accelerate training by up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main)._\n",
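"\n",
"You can decide whether to use BF16 and FlashAttention from the GPU's CUDA compute capability. Ampere and newer GPUs report a major version of 8 or higher, while the T4 reports 7.5. The following is a minimal sketch, not part of this guide's pipeline. `pick_precision_and_attention` is a hypothetical helper that returns two names. The first is a dtype name matching `torch.bfloat16` or `torch.float16`. The second is an attention backend name, `'flash_attention_2'` or `'eager'`, both values accepted by the `attn_implementation` argument in Transformers. The helper assumes you fall back to FP16 and eager attention on older GPUs.\n",
"\n",
"```python\n",
"def pick_precision_and_attention(major, flash_attn_installed):\n",
"    # major: first element of torch.cuda.get_device_capability(), e.g. 7 on a T4 (7.5).\n",
"    if major >= 8:\n",
"        # Ampere or newer: BF16 is supported, and FlashAttention-2 if the package is installed.\n",
"        attn = 'flash_attention_2' if flash_attn_installed else 'eager'\n",
"        return 'bfloat16', attn\n",
"    # Older GPUs such as the T4: use FP16 and standard eager attention.\n",
"    return 'float16', 'eager'\n",
"\n",
"# On a real GPU you would pass the detected value instead:\n",
"#   import torch\n",
"#   major, _ = torch.cuda.get_device_capability()\n",
"print(pick_precision_and_attention(7, flash_attn_installed=False))  # T4\n",
"print(pick_precision_and_attention(8, flash_attn_installed=True))   # A100 or L4\n",
"```\n",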
|
|
"\n",
|
|
"You need a valid Hugging Face token to publish your model. If you are running in Google Colab, you can store your Hugging Face token securely in Colab secrets. Otherwise, you can pass the token directly to the `login` method. Make sure the token has write access, because the model is pushed to the Hub during training."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "b6d79c93"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Login into Hugging Face Hub\n",
|
|
"from huggingface_hub import login\n",
|
|
"login()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "42c60525"
|
|
},
|
|
"source": [
|
|
"## Create and prepare the fine-tuning dataset\n",
|
|
"\n",
|
|
"When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.\n",
|
|
"\n",
|
|
"As an example, this guide focuses on the following use case:\n",
|
|
"\n",
|
|
"- Fine-tune a natural language to SQL model for seamless integration into a data analysis tool. The objective is to significantly reduce the time and expertise required for SQL query generation, enabling even non-technical users to extract meaningful insights from data.\n",
|
|
"\n",
|
|
"Text-to-SQL can be a good use case for fine-tuning LLMs, as it is a complex task that requires a lot of (internal) knowledge about the data and the SQL language.\n",
|
|
"\n",
|
|
"Once you have determined that fine-tuning is the right solution, you need a dataset to fine-tune. The dataset should be a diverse set of demonstrations of the task(s) you want to solve. There are several ways to create such a dataset, including:\n",
|
|
"\n",
|
|
"- Using existing open-source datasets, such as [Spider](https://huggingface.co/datasets/spider)\n",
|
|
"- Using synthetic datasets created by LLMs, such as [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)\n",
|
|
"- Using datasets created by humans, such as [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k)\n",
|
|
"- Using a combination of the methods, such as [Orca](https://huggingface.co/datasets/Open-Orca/OpenOrca)\n",
|
|
"\n",
|
|
"Each of these methods has its own advantages and disadvantages, and the right choice depends on your budget, time, and quality requirements. For example, using an existing dataset is the easiest option but might not be tailored to your specific use case. Having domain experts create the data might be the most accurate option but can be time-consuming and expensive. You can also combine several methods to create an instruction dataset, as shown in [Orca: Progressive Learning from Complex Explanation Traces of GPT-4](https://arxiv.org/abs/2306.02707).\n",
|
|
"\n",
|
|
"This guide uses an existing dataset ([philschmid/gretel-synthetic-text-to-sql](https://huggingface.co/datasets/philschmid/gretel-synthetic-text-to-sql)). It is a high-quality synthetic text-to-SQL dataset that includes natural language instructions, schema definitions, reasoning, and the corresponding SQL queries.\n",
|
|
"\n",
|
|
"[Hugging Face TRL](https://huggingface.co/docs/trl/en/index) supports automatic templating of conversational dataset formats. This means you only need to convert your dataset into the right JSON objects, and `trl` applies the chat template and formats the data for you.\n",
|
|
"\n",
|
|
"```\n",
|
|
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
|
|
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
|
|
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
|
|
"```"
|
|
]
|
|
},
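{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can produce the conversational format above with Python's standard library. This is a minimal sketch. `to_messages` is a hypothetical helper, and the example row is made up, so map your own dataset fields onto the three roles. Writing one JSON object per line (JSON Lines) gives you records in the same shape as the example above.\n",
"\n",
"```python\n",
"import json\n",
"\n",
"# Hypothetical helper: wrap one example in the system/user/assistant message format.\n",
"def to_messages(system, user, assistant):\n",
"    return {'messages': [\n",
"        {'role': 'system', 'content': system},\n",
"        {'role': 'user', 'content': user},\n",
"        {'role': 'assistant', 'content': assistant},\n",
"    ]}\n",
"\n",
"# Made-up example row; in practice these strings come from your dataset.\n",
"record = to_messages(\n",
"    'You are a text to SQL query translator.',\n",
"    'How many rows are in the Menu table?',\n",
"    'SELECT COUNT(*) FROM Menu;',\n",
")\n",
"\n",
"# One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.\n",
"line = json.dumps(record, ensure_ascii=False)\n",
"print(line)\n",
"assert json.loads(line) == record  # the record round-trips losslessly\n",
"```"
]
},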
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "c4ecf6db"
|
|
},
|
|
"source": [
|
|
"The [philschmid/gretel-synthetic-text-to-sql](https://huggingface.co/datasets/philschmid/gretel-synthetic-text-to-sql) dataset contains over 100k samples. To keep the guide small, it is downsampled to 10,000 samples.\n",
|
|
"\n",
|
|
"You can now use the Hugging Face Datasets library to load the dataset. Then create a prompt template that combines the natural language instruction with the schema definition and adds a system message for your assistant."
|
|
]
|
|
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "40c3a2cf"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e1947559b8c42f0ab2cf28efc6535b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0%| | 0.00/737 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5568dea8cc914e5ca102d0da61ad8238",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"synthetic_text_to_sql_train.snappy.parqu(…): 0%| | 0.00/32.4M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "990f6c69a0104030bf1b6ebe1caf4734",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"synthetic_text_to_sql_test.snappy.parque(…): 0%| | 0.00/1.90M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "61383ce954334b9bbf27a9ae655669a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/100000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "64252f9a3e9a4b7093dbe207b1c90f3d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/5851 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cec15eeff0bb484aa2039f41ce36c5b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/12500 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'content': 'You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.', 'role': 'system'}\n",
"{'content': \"Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\\n\\n<SCHEMA>\\nCREATE TABLE Menu (id INT PRIMARY KEY, name VARCHAR(255), category VARCHAR(255), price DECIMAL(5,2));\\n</SCHEMA>\\n\\n<USER_QUERY>\\nCalculate the average price of all menu items in the Vegan category\\n</USER_QUERY>\\n\", 'role': 'user'}\n",
"{'content': \"SELECT AVG(price) FROM Menu WHERE category = 'Vegan';\", 'role': 'assistant'}\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"# System message for the assistant\n",
"system_message = \"\"\"You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\"\"\"\n",
"\n",
"# User prompt that combines the user query and the schema\n",
"user_prompt = \"\"\"Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n",
"\n",
"<SCHEMA>\n",
"{context}\n",
"</SCHEMA>\n",
"\n",
"<USER_QUERY>\n",
"{question}\n",
"</USER_QUERY>\n",
"\"\"\"\n",
"\n",
"def create_conversation(sample):\n",
"    # Map one dataset row to the conversational (system/user/assistant) messages format\n",
"    return {\n",
"        \"messages\": [\n",
"            {\"role\": \"system\", \"content\": system_message},\n",
"            {\"role\": \"user\", \"content\": user_prompt.format(question=sample[\"sql_prompt\"], context=sample[\"sql_context\"])},\n",
"            {\"role\": \"assistant\", \"content\": sample[\"sql\"]}\n",
"        ]\n",
"    }\n",
"\n",
"# Load dataset from the hub and keep a shuffled subset of 12,500 samples\n",
"dataset = load_dataset(\"philschmid/gretel-synthetic-text-to-sql\", split=\"train\")\n",
"dataset = dataset.shuffle().select(range(12500))\n",
"\n",
"# Convert dataset to OAI messages and drop the original columns\n",
"dataset = dataset.map(create_conversation, remove_columns=dataset.column_names, batched=False)\n",
"# Split dataset into 80% training samples (10,000) and 20% test samples (2,500)\n",
"dataset = dataset.train_test_split(test_size=0.2)\n",
"\n",
"# Print the messages of the first training sample\n",
"for item in dataset[\"train\"][0][\"messages\"]:\n",
"    print(item)\n"
]
},
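The sample counts above come from simple arithmetic: with `test_size=0.2`, 12,500 shuffled rows yield 10,000 training and 2,500 test samples. These numbers match the "Tokenizing train dataset" (10,000) and "Tokenizing eval dataset" (2,500) progress bars later in the notebook. The following plain-Python sketch mirrors the shuffle, select, and split steps on a list of toy rows. The helper name `shuffle_select_split` and the toy rows are hypothetical, and rounding the test size up is an assumption. The Hugging Face `datasets` implementation may differ in details such as reshuffling inside `train_test_split`.

```python
import math
import random


def shuffle_select_split(rows, n_select, test_size, seed=0):
    """Hypothetical helper: shuffle, keep the first n_select rows, then split.

    Mirrors dataset.shuffle().select(range(n)).train_test_split(test_size=...)
    in spirit only; the Hugging Face implementation differs in details.
    """
    rng = random.Random(seed)
    shuffled = rows[:]  # copy so the caller's list is not mutated
    rng.shuffle(shuffled)
    selected = shuffled[:n_select]
    n_test = math.ceil(len(selected) * test_size)  # assumption: round the test size up
    return selected[n_test:], selected[:n_test]


# Toy stand-in for the 100,000-row train split
rows = [{"sql_prompt": f"question {i}", "sql": f"SELECT {i};"} for i in range(100_000)]
train, test = shuffle_select_split(rows, n_select=12_500, test_size=0.2)
print(len(train), len(test))  # 10000 2500
```

The fixed `seed` makes the sketch reproducible. The notebook's `dataset.shuffle()` call passes no seed, so each run selects a different subset.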
{
"cell_type": "markdown",
"metadata": {
"id": "c0eb2e06"
},
"source": [
"## Fine-tune Gemma using TRL and the SFTTrainer\n",
"\n",
"You are now ready to fine-tune your model. The Hugging Face TRL [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to run supervised fine-tuning on open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library. It supports all the same features, including logging, evaluation, and checkpointing, and adds several quality-of-life features, including:\n",
"\n",
"* Dataset formatting, including conversational and instruction formats\n",
"* Training on completions only, ignoring prompts\n",
"* Packing datasets for more efficient training\n",
"* Parameter-efficient fine-tuning (PEFT) support, including QLoRA\n",
"* Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)\n",
"\n",
"The following code loads the Gemma model and tokenizer from Hugging Face and initializes the quantization configuration."
]
},
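The "training on completions only" bullet means the loss is computed only on the assistant's tokens. A common way to do this is to set the label of every prompt token to `-100`, which is the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below shows the idea on plain token-id lists. The function name and the toy token ids are hypothetical, and TRL's own implementation works from the chat template and is more involved. The labels stay aligned with `input_ids`, because causal language models in `transformers` shift the labels internally.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_completion_only_labels(prompt_ids, completion_ids):
    """Hypothetical helper: concatenate prompt and completion token ids and mask
    the prompt positions so only completion tokens contribute to the loss."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


# Toy token ids: 4 prompt tokens followed by 3 completion tokens
input_ids, labels = build_completion_only_labels([2, 105, 2364, 107], [6390, 1366, 1])
print(input_ids)  # [2, 105, 2364, 107, 6390, 1366, 1]
print(labels)     # [-100, -100, -100, -100, 6390, 1366, 1]
```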
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "18069ed2"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b17e7e80e884df59a0bea8b6f6802e9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f5cfbb54cfec4e7d93ed2eb0d5b2e62a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5f8ae73ccd3478985fbc37e95b89de8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/2011 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9a1c13e560c4790b626ab3fd045e1b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/181 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb68212e51dc480d99a66d131838858e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4a245124cc74c4db7b6ad73a1b65f33",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a464b2885d6649b586c73e74fcca0f07",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6525a790f4440ff989d5c815dd94da7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"chat_template.jinja: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForImageTextToText, BitsAndBytesConfig\n",
"\n",
"# Hugging Face model id\n",
"model_id = \"google/gemma-4-E2B\" # @param [\"google/gemma-4-E2B\",\"google/gemma-4-E4B\"] {\"allow-input\":true}\n",
"\n",
"# Use bfloat16 on GPUs that support it (compute capability 8.0+, e.g. Ampere), otherwise float16\n",
"if torch.cuda.get_device_capability()[0] >= 8:\n",
"    torch_dtype = torch.bfloat16\n",
"else:\n",
"    torch_dtype = torch.float16\n",
"\n",
"# Define model init arguments\n",
"model_kwargs = dict(\n",
"    dtype=torch_dtype,\n",
"    device_map=\"auto\",  # automatically place the model layers on the available devices\n",
")\n",
"\n",
"# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage\n",
"model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_use_double_quant=True,\n",
"    bnb_4bit_quant_type='nf4',\n",
"    bnb_4bit_compute_dtype=model_kwargs['dtype'],\n",
"    bnb_4bit_quant_storage=model_kwargs['dtype'],\n",
")\n",
"\n",
"# Load model and tokenizer\n",
"model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-4-E2B-it\")  # Load the instruction-tuned tokenizer to use the official Gemma chat template"
]
},
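`bnb_4bit_use_double_quant=True` enables the double quantization described in the QLoRA paper. With a block size of 64, storing one 32-bit absmax constant per block costs 32/64 = 0.5 extra bits per parameter. Quantizing those constants to 8 bits, with one 32-bit constant per 256 of them, reduces the cost to 8/64 + 32/(64·256) ≈ 0.127 bits per parameter. The sketch below reproduces this arithmetic and estimates weight memory for a hypothetical 5-billion-parameter model. The block sizes are the paper's values, not settings read from this notebook's `BitsAndBytesConfig`. Real memory use also includes activations, the LoRA adapters, optimizer state, and any layers that stay unquantized.

```python
def quant_overhead_bits(block_size=64, double_quant=True, block_size_2=256):
    """Extra bits per parameter spent on quantization constants (QLoRA paper values)."""
    if not double_quant:
        return 32 / block_size  # one fp32 absmax per block of weights
    # 8-bit absmax per block, plus one fp32 constant per block_size_2 of those absmaxes
    return 8 / block_size + 32 / (block_size * block_size_2)


def weight_memory_gb(n_params, bits_per_weight=4, double_quant=True):
    """Approximate memory (GB, 1e9 bytes) for 4-bit weights plus their constants."""
    bits = bits_per_weight + quant_overhead_bits(double_quant=double_quant)
    return n_params * bits / 8 / 1e9


print(round(quant_overhead_bits(double_quant=False), 3))  # 0.5
print(round(quant_overhead_bits(double_quant=True), 3))   # 0.127
# Hypothetical 5-billion-parameter model, for illustration only
print(weight_memory_gb(5e9, double_quant=False))
print(weight_memory_gb(5e9, double_quant=True))
```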
{
"cell_type": "markdown",
"metadata": {
"id": "37ec1d1b"
},
"source": [
"The `SFTTrainer` has a built-in integration with `peft`, which makes it straightforward to tune LLMs efficiently with QLoRA. You only need to create a `LoraConfig` and pass it to the trainer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ed00e846"
},
"outputs": [],
"source": [
"from peft import LoraConfig\n",
"\n",
"peft_config = LoraConfig(\n",
"    lora_alpha=16,\n",
"    lora_dropout=0.05,\n",
"    r=16,\n",
"    bias=\"none\",\n",
"    target_modules=\"all-linear\",\n",
"    task_type=\"CAUSAL_LM\",\n",
"    modules_to_save=[\"lm_head\", \"embed_tokens\"],  # make sure to save the lm_head and embed_tokens as you train the special tokens\n",
"    ensure_weight_tying=True,\n",
")"
]
},
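With `r=16` and `lora_alpha=16`, each targeted linear layer with weight W (shape d_out × d_in) gets two trainable matrices, A (r × d_in) and B (d_out × r). The layer output becomes W·x + (lora_alpha / r)·B·A·x, so the scaling factor in this configuration is 16 / 16 = 1. The plain-Python sketch below shows that forward pass and the per-layer parameter count r·(d_in + d_out). The toy matrix sizes and the 4096 hidden size are hypothetical, dropout is omitted, and PEFT's real implementation operates on tensors. PEFT initializes B to zeros by default, so an adapted layer initially reproduces the base layer exactly.

```python
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, lora_alpha, r):
    """y = W x + (lora_alpha / r) * B (A x); dropout omitted for clarity."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = lora_alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return r * (d_in + d_out)


rng = random.Random(0)
d_in, d_out, r = 8, 6, 2  # toy sizes, not Gemma's real dimensions
W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]  # zero init: the adapter starts as a no-op
x = [rng.uniform(-1, 1) for _ in range(d_in)]

print(lora_forward(W, A, B, x, lora_alpha=16, r=r) == matvec(W, x))  # True
print(lora_param_count(4096, 4096, 16))  # 131072 for a hypothetical 4096 x 4096 layer
```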
{
"cell_type": "markdown",
"metadata": {
"id": "bbd9fc1b"
},
"source": [
"Before you can start training, you need to define the hyperparameters you want to use in an `SFTConfig` instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "989be3c1"
},
"outputs": [],
"source": [
"import torch\n",
"from trl import SFTConfig\n",
"\n",
"args = SFTConfig(\n",
"    output_dir=\"gemma-text-to-sql\",  # local output directory and Hub repository id\n",
"    max_length=512,  # max sequence length for tokenization (and for packing, if enabled)\n",
"    num_train_epochs=3,  # number of training epochs\n",
"    per_device_train_batch_size=1,  # batch size per device during training\n",
"    optim=\"adamw_torch_fused\",  # use fused adamw optimizer\n",
"    logging_steps=10,  # log every 10 steps\n",
"    save_strategy=\"epoch\",  # save checkpoint every epoch\n",
"    eval_strategy=\"epoch\",  # evaluate checkpoint every epoch\n",
"    learning_rate=5e-5,  # learning rate\n",
"    fp16=model.dtype == torch.float16,  # use float16 precision if the model is in float16\n",
"    bf16=model.dtype == torch.bfloat16,  # use bfloat16 precision if the model is in bfloat16\n",
"    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper\n",
"    lr_scheduler_type=\"constant\",  # use constant learning rate scheduler\n",
"    push_to_hub=True,  # push model to hub\n",
"    report_to=\"tensorboard\",  # report metrics to tensorboard\n",
"    dataset_kwargs={\n",
"        \"add_special_tokens\": False,  # the chat template already adds the special tokens\n",
"        \"append_concat_token\": True,  # add EOS token as separator between concatenated examples\n",
"    }\n",
")"
]
},
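As its comment says, the `append_concat_token` entry concerns separating concatenated examples. Examples are concatenated when packing is enabled (`packing=True` in `SFTConfig`), and this configuration does not set it. Packing joins tokenized examples with an EOS token between them and cuts the resulting stream into chunks of at most `max_length` tokens, which wastes fewer padding tokens. The sketch below illustrates that naive concatenate-and-chunk strategy on token-id lists. The function name and toy examples are hypothetical, and TRL's packing implementation may use a different strategy. The EOS id `1` matches the `eos_token_id` reported in the training warning later in this notebook.

```python
def pack_examples(examples, max_length, eos_id):
    """Hypothetical helper: concatenate token-id lists, appending eos_id after
    each example, then split the stream into chunks of at most max_length."""
    stream = []
    for ids in examples:
        stream.extend(ids)
        stream.append(eos_id)  # separator between concatenated examples
    return [stream[i:i + max_length] for i in range(0, len(stream), max_length)]


EOS = 1  # Gemma eos_token_id as reported in the training log
examples = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]
for chunk in pack_examples(examples, max_length=4, eos_id=EOS):
    print(chunk)
# [10, 11, 12, 1]
# [20, 21, 1, 30]
# [31, 32, 33, 1]
```

This naive strategy can split an example across two chunks, as it does with the third example here. Some packing implementations avoid this at the cost of extra padding.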
{
"cell_type": "markdown",
"metadata": {
"id": "dd88e798"
},
"source": [
"You now have every building block you need to create the `SFTTrainer` and start training your model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ade95df7"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9061644033864e22a5cd8905051b6637",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Tokenizing train dataset: 0%| | 0/10000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f63804866860487cb9135f5729d76f01",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Tokenizing eval dataset: 0%| | 0/2500 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from trl import SFTTrainer\n",
"\n",
"# Create Trainer object\n",
"trainer = SFTTrainer(\n",
"    model=model,\n",
"    args=args,\n",
"    train_dataset=dataset[\"train\"],\n",
"    eval_dataset=dataset[\"test\"],\n",
"    peft_config=peft_config,\n",
"    processing_class=tokenizer,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fad61a6a"
},
"source": [
"Start training by calling the `train()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "995e7e38"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='1875' max='1875' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [1875/1875 28:32, Epoch 3/3]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Epoch</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.536652</td>\n",
" <td>0.530056</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.430735</td>\n",
" <td>0.464053</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.386358</td>\n",
" <td>0.443147</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Start training; checkpoints are saved to the output directory every epoch and, because push_to_hub=True, pushed to the Hub\n",
"trainer.train()\n",
"\n",
"# Save the final model to the output directory (and push it to the Hub, since push_to_hub=True)\n",
"trainer.save_model()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b47b9733"
},
"source": [
"Before you can test your model, make sure to free the memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "40a32ed7"
},
"outputs": [],
"source": [
"import gc\n",
"\n",
"# Free the memory held by the model and the trainer\n",
"del model\n",
"del trainer\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "862e9728"
},
"source": [
"When using QLoRA, you only train adapters, not the full model. As a result, the checkpoints saved during training contain only the adapter weights (plus any `modules_to_save`), not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights with the `merge_and_unload` method and then save the model with the `save_pretrained` method. This produces a standard standalone model that can be used for inference without loading the adapter separately.\n",
"\n",
"Note: Merging the adapter into the model requires more than 30 GB of CPU memory. You can skip this step and continue with Test Model Inference."
]
},
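Merging works because the LoRA update is linear. With s = lora_alpha / r, W·x + s·B·(A·x) equals (W + s·B·A)·x, so `merge_and_unload` can fold s·B·A into the base weight once and then drop the adapter matrices. The plain-Python sketch below checks this identity on small random matrices. The matrix sizes are hypothetical, and the sketch ignores quantization entirely. In this notebook, the base model is reloaded without a quantization config before merging.

```python
import random


def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def merge_lora(W, A, B, lora_alpha, r):
    """Return W + (lora_alpha / r) * B @ A, i.e. the merged weight."""
    scale = lora_alpha / r
    BA = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]


rng = random.Random(42)
d_in, d_out, r, alpha = 5, 4, 2, 16  # toy sizes
W = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(r)]
B = [[rng.gauss(0, 1) for _ in range(r)] for _ in range(d_out)]  # as if already trained
x = [rng.gauss(0, 1) for _ in range(d_in)]

unmerged = [w + (alpha / r) * d for w, d in zip(matvec(W, x), matvec(B, matvec(A, x)))]
merged = matvec(merge_lora(W, A, B, alpha, r), x)
print(max(abs(u - m) for u, m in zip(unmerged, merged)) < 1e-9)  # True (equal up to float rounding)
```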
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "761e324b"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b58cae40ed3d40d89be8b4065548a69d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/2011 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2d9db55b847041a5a3b446001239202a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Writing model shards: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"('merged_model/tokenizer_config.json',\n",
" 'merged_model/chat_template.jinja',\n",
" 'merged_model/tokenizer.json')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from peft import PeftModel\n",
"\n",
"# Load the base model\n",
"model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)\n",
"\n",
"# Merge the LoRA adapter into the base model and save the result\n",
"peft_model = PeftModel.from_pretrained(model, args.output_dir)\n",
"merged_model = peft_model.merge_and_unload()\n",
"merged_model.save_pretrained(\"merged_model\", safe_serialization=True, max_shard_size=\"2GB\")\n",
"\n",
"# Save the tokenizer (including its chat template) next to the merged weights\n",
"tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
"tokenizer.save_pretrained(\"merged_model\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bf86e31d"
},
"source": [
"## Test Model Inference and generate SQL queries\n",
"\n",
"After training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.\n",
"\n",
"Note: Evaluating generative AI models is not a trivial task, since one input can have multiple correct outputs. This guide focuses only on manual evaluation and vibe checks."
]
},
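Beyond manual checks, a common automated signal for text-to-SQL is execution accuracy. You run the reference query and the generated query against the same database and compare the results. The sketch below does this with Python's built-in `sqlite3` module on an in-memory database. It assumes the schema is valid SQLite, and it uses hypothetical sample rows, because the dataset's `sql_context` may not include data for every table. For statements such as `DELETE`, it also compares the table contents left behind. The helper names are hypothetical. Two queries can agree on one toy database by coincidence, so treat this as a rough filter rather than a definitive metric.

```python
import sqlite3


def table_snapshot(conn, table):
    """Return all rows of a table in a deterministic order."""
    return sorted(conn.execute(f"SELECT * FROM {table}").fetchall())


def run_on_fresh_db(schema_sql, seed_sql, query, table):
    """Create an in-memory DB, run one query, and return (result rows, table state)."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema_sql + seed_sql)
        rows = sorted(conn.execute(query).fetchall())  # empty for DELETE/UPDATE/INSERT
        return rows, table_snapshot(conn, table)
    finally:
        conn.close()


def execution_match(schema_sql, seed_sql, reference, generated, table):
    """True when both queries return the same rows and leave the same table state."""
    try:
        return (run_on_fresh_db(schema_sql, seed_sql, reference, table)
                == run_on_fresh_db(schema_sql, seed_sql, generated, table))
    except sqlite3.Error:
        return False  # e.g. the generated SQL does not parse


schema = ("CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), "
          "download_speed INT, upload_speed INT, price DECIMAL(5,2));")
# Hypothetical rows: the dataset sample only provides the schema
seed = ("INSERT INTO broadband_plans VALUES (3001, 'Basic', 100, 10, 29.99);"
        "INSERT INTO broadband_plans VALUES (3002, 'Pro', 500, 50, 59.99);")

print(execution_match(schema, seed,
                      "DELETE FROM broadband_plans WHERE plan_id = 3001;",
                      "DELETE FROM broadband_plans WHERE plan_name = 'Basic';",
                      "broadband_plans"))  # True
print(execution_match(schema, seed,
                      "DELETE FROM broadband_plans WHERE plan_id = 3001;",
                      "DELETE FROM broadband_plans WHERE plan_id = 3002;",
                      "broadband_plans"))  # False
```

The first comparison uses the reference/generated pair printed later in this notebook. It returns `True` only because the hypothetical row with `plan_id` 3001 happens to be named `'Basic'`, which shows why one database is weak evidence of equivalence.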
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aab1c5c5"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "224c4db7e94445d9adb369eeac3c0bd2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/2012 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.\n"
]
}
],
"source": [
"import torch\n",
"from transformers import pipeline\n",
"\n",
"model_id = \"merged_model\"\n",
"\n",
"# Load the merged model (base weights with the LoRA adapter merged in)\n",
"model = AutoModelForImageTextToText.from_pretrained(\n",
"    model_id,\n",
"    device_map=\"auto\",\n",
"    dtype=\"auto\",\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3dccb57c"
},
"source": [
"Let's load a random sample from the test dataset and generate a SQL command."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1fd887f4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<bos><|turn>system\n",
"You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.<turn|>\n",
"<|turn>user\n",
"Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n",
"\n",
"<SCHEMA>\n",
"CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));\n",
"</SCHEMA>\n",
"\n",
"<USER_QUERY>\n",
"Delete a broadband plan from the 'broadband_plans' table\n",
"</USER_QUERY><turn|>\n",
"<|turn>model\n",
"\n",
"Context:\n",
" CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));\n",
"Query:\n",
" Delete a broadband plan from the 'broadband_plans' table\n",
"Original Answer:\n",
"DELETE FROM broadband_plans WHERE plan_id = 3001;\n",
"Generated Answer:\n",
"DELETE FROM broadband_plans\n",
"WHERE plan_name = 'Basic';\n"
]
}
],
"source": [
"from random import randint\n",
"import re\n",
"from transformers import pipeline, GenerationConfig\n",
"\n",
"config = GenerationConfig.from_pretrained(model_id)\n",
"config.max_new_tokens = 256\n",
"\n",
"# Load the model and tokenizer into the pipeline\n",
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
"\n",
"# Load a random sample from the test dataset (randint includes both ends, so subtract 1)\n",
"rand_idx = randint(0, len(dataset[\"test\"]) - 1)\n",
"test_sample = dataset[\"test\"][rand_idx]\n",
"\n",
"# Convert the test example into a prompt using the Gemma chat template\n",
"prompt = pipe.tokenizer.apply_chat_template(test_sample[\"messages\"][:2], tokenize=False, add_generation_prompt=True)\n",
"print(prompt)\n",
"\n",
"# Generate the SQL query\n",
"outputs = pipe(prompt, generation_config=config)\n",
"\n",
"# Extract the schema and user query, then print them with the original and generated answers\n",
"print(f\"Context:\\n\", re.search(r'<SCHEMA>\\n(.*?)\\n</SCHEMA>', test_sample['messages'][1]['content'], re.DOTALL).group(1).strip())\n",
"print(f\"Query:\\n\", re.search(r'<USER_QUERY>\\n(.*?)\\n</USER_QUERY>', test_sample['messages'][1]['content'], re.DOTALL).group(1).strip())\n",
"print(f\"Original Answer:\\n{test_sample['messages'][2]['content']}\")\n",
"print(f\"Generated Answer:\\n{outputs[0]['generated_text'][len(prompt):].strip()}\")"
]
},
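The printed prompt shows the Gemma 4 turn format. Each message is wrapped as `<|turn>{role}`, a newline, the content, and `<turn|>`. The assistant role is rendered as `model`, and `add_generation_prompt=True` appends an open `<|turn>model` turn. The sketch below reproduces that layout in plain Python for the text-only messages used here. It is inferred only from the printed output, including the stripping of the trailing newline after `</USER_QUERY>`. The authoritative rules live in the tokenizer's `chat_template.jinja`, which may handle cases this sketch ignores, such as non-text content. The function name is hypothetical.

```python
def render_gemma_turns(messages, add_generation_prompt=True, bos="<bos>"):
    """Approximate the prompt layout printed above (inferred, not the real template)."""
    role_names = {"system": "system", "user": "user", "assistant": "model"}
    parts = [bos]
    for message in messages:
        role = role_names[message["role"]]
        # The printed prompt shows each message's content stripped of surrounding whitespace
        parts.append(f"<|turn>{role}\n{message['content'].strip()}<turn|>\n")
    if add_generation_prompt:
        parts.append("<|turn>model\n")  # open turn for the model's answer
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a text to SQL query translator."},
    {"role": "user", "content": "<USER_QUERY>\nList all plans\n</USER_QUERY>\n"},
]
print(render_gemma_turns(messages))
```

For real prompts, keep using `tokenizer.apply_chat_template`. This sketch is only meant to make the special-token layout in the printed prompt easier to read.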
{
"cell_type": "markdown",
"metadata": {
"id": "6f8ff452"
},
"source": [
"## Summary and next steps\n",
"\n",
"This tutorial covered how to fine-tune a Gemma model using TRL and QLoRA. Check out the following docs next:\n",
"\n",
"* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
"* Learn how to [fine-tune Gemma for vision tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora).\n",
"* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
"* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
"* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "huggingface_text_finetune_qlora.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}