docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,186 @@
|
||||
# Gemma in PyTorch
|
||||
|
||||
**Gemma** is a family of lightweight, state-of-the art open models built from research and technology used to create Google Gemini models. They include both text-only and multimodal decoder-only large language models, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:
|
||||
|
||||
* [Gemma on Google AI](https://ai.google.dev/gemma)
|
||||
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma-3)
|
||||
* [Gemma on Vertex AI Model Garden](https://pantheon.corp.google.com/vertex-ai/publishers/google/model-garden/gemma3)
|
||||
|
||||
This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.
|
||||
|
||||
## Updates
|
||||
|
||||
* [March 12th, 2025 🔥] Support Gemma v3. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-3/pytorch) and [Hugging Face](https://huggingface.co/models?other=gemma_torch)
|
||||
|
||||
* [June 26th, 2024] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
|
||||
|
||||
* [April 9th, 2024] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
|
||||
|
||||
* [April 5, 2024] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
|
||||
|
||||
## Download Gemma model checkpoint
|
||||
|
||||
You can find the model checkpoints on Kaggle:
|
||||
|
||||
- [Gemma 3](https://www.kaggle.com/models/google/gemma-3/pyTorch)
|
||||
- [Gemma 2](https://www.kaggle.com/models/google/gemma-2/pyTorch)
|
||||
- [Gemma](https://www.kaggle.com/models/google/gemma/pyTorch)
|
||||
|
||||
Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go the the model repository of the model of interest and click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub` installed, you can also run:
|
||||
|
||||
```
|
||||
huggingface-cli download google/gemma-3-4b-it-pytorch
|
||||
```
|
||||
|
||||
The following model sizes are available:
|
||||
|
||||
- **Gemma 3**:
|
||||
- **Text only**: 1b
|
||||
- **Multimodal**: 4b, 12b, 27b_v3
|
||||
- **Gemma 2**:
|
||||
- **Text only**: 2b-v2, 9b, 27b
|
||||
- **Gemma**:
|
||||
- **Text only**: 2b, 7b
|
||||
|
||||
|
||||
Note that you can choose between the 1B, 4B, 12B, and 27B variants.
|
||||
|
||||
```
|
||||
VARIANT=<1b, 2b, 2b-v2, 4b, 7b, 9b, 12b, 27b, 27b_v3>
|
||||
CKPT_PATH=<Insert ckpt path here>
|
||||
```
|
||||
|
||||
## Try it free on Colab
|
||||
|
||||
Follow the steps at
|
||||
[https://ai.google.dev/gemma/docs/pytorch_gemma](https://ai.google.dev/gemma/docs/pytorch_gemma).
|
||||
|
||||
## Try it out with PyTorch
|
||||
|
||||
Prerequisite: make sure you have setup docker permission properly as a non-root user.
|
||||
|
||||
```bash
|
||||
sudo usermod -aG docker $USER
|
||||
newgrp docker
|
||||
```
|
||||
|
||||
### Build the docker image.
|
||||
|
||||
```bash
|
||||
DOCKER_URI=gemma:${USER}
|
||||
|
||||
docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}
|
||||
```
|
||||
|
||||
### Run Gemma inference on CPU.
|
||||
|
||||
> NOTE: This is a multimodal example. Use a multimodal variant.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_multimodal.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Run Gemma inference on GPU.
|
||||
|
||||
> NOTE: This is a multimodal example. Use a multimodal variant.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
--gpus all \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_multimodal.py \
|
||||
--device=cuda \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}"
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
## Try It out with PyTorch/XLA
|
||||
|
||||
### Build the docker image (CPU, TPU).
|
||||
|
||||
```bash
|
||||
DOCKER_URI=gemma_xla:${USER}
|
||||
|
||||
docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
|
||||
```
|
||||
|
||||
### Build the docker image (GPU).
|
||||
|
||||
```bash
|
||||
DOCKER_URI=gemma_xla_gpu:${USER}
|
||||
|
||||
docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}
|
||||
```
|
||||
|
||||
### Run Gemma inference on CPU.
|
||||
|
||||
> NOTE: This is a multimodal example. Use a multimodal variant.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
--shm-size 4gb \
|
||||
-e PJRT_DEVICE=CPU \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_xla.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Run Gemma inference on TPU.
|
||||
|
||||
Note: be sure to use the docker container built from `xla.Dockerfile`.
|
||||
|
||||
```bash
|
||||
docker run -t --rm \
|
||||
--shm-size 4gb \
|
||||
-e PJRT_DEVICE=TPU \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_xla.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Run Gemma inference on GPU.
|
||||
|
||||
Note: be sure to use the docker container built from `xla_gpu.Dockerfile`.
|
||||
|
||||
```bash
|
||||
docker run -t --rm --privileged \
|
||||
--shm-size=16g --net=host --gpus all \
|
||||
-e USE_CUDA=1 \
|
||||
-e PJRT_DEVICE=CUDA \
|
||||
-v ${CKPT_PATH}:/tmp/ckpt \
|
||||
${DOCKER_URI} \
|
||||
python scripts/run_xla.py \
|
||||
--ckpt=/tmp/ckpt \
|
||||
--variant="${VARIANT}" \
|
||||
# add `--quant` for the int8 quantized model.
|
||||
```
|
||||
|
||||
### Tokenizer Notes
|
||||
|
||||
99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. Unused tokens are in the string format of `<unused[0-97]>` with token id range of `[7-104]`.
|
||||
|
||||
```
|
||||
"<unused0>": 7,
|
||||
"<unused1>": 8,
|
||||
"<unused2>": 9,
|
||||
...
|
||||
"<unused98>": 104,
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This is not an officially supported Google product.
|
||||
@@ -0,0 +1,107 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import contextlib
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from absl import app, flags
|
||||
|
||||
from gemma import config
|
||||
from gemma import model as gemma_model
|
||||
|
||||
# Define flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True)
|
||||
flags.DEFINE_string('variant', '4b', 'Model variant.')
|
||||
flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
|
||||
flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.')
|
||||
flags.DEFINE_integer('seed', 12345, 'Random seed.')
|
||||
flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
|
||||
flags.DEFINE_string('prompt', 'What are large language models?', 'Input prompt for the model.')
|
||||
|
||||
# Define valid text only model variants
|
||||
_VALID_MODEL_VARIANTS = ['2b', '2b-v2', '7b', '9b', '27b', '1b']
|
||||
|
||||
# Define valid devices
|
||||
_VALID_DEVICES = ['cpu', 'cuda']
|
||||
|
||||
# Validator function for the 'variant' flag
|
||||
def validate_variant(variant):
|
||||
if variant not in _VALID_MODEL_VARIANTS:
|
||||
raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}')
|
||||
return True
|
||||
|
||||
# Validator function for the 'device' flag
|
||||
def validate_device(device):
|
||||
if device not in _VALID_DEVICES:
|
||||
raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}')
|
||||
return True
|
||||
|
||||
# Register the validator for the 'variant' flag
|
||||
flags.register_validator('variant', validate_variant, message='Invalid model variant.')
|
||||
|
||||
# Register the validator for the 'device' flag
|
||||
flags.register_validator('device', validate_device, message='Invalid device.')
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_tensor_type(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
def main(_):
|
||||
# Construct the model config.
|
||||
model_config = config.get_model_config(FLAGS.variant)
|
||||
model_config.dtype = "float32"
|
||||
model_config.quant = FLAGS.quant
|
||||
|
||||
# Seed random.
|
||||
random.seed(FLAGS.seed)
|
||||
np.random.seed(FLAGS.seed)
|
||||
torch.manual_seed(FLAGS.seed)
|
||||
|
||||
# Create the model and load the weights.
|
||||
device = torch.device(FLAGS.device)
|
||||
with _set_default_tensor_type(model_config.get_dtype()):
|
||||
model = gemma_model.GemmaForCausalLM(model_config)
|
||||
model.load_weights(FLAGS.ckpt)
|
||||
model = model.to(device).eval()
|
||||
print("Model loading done")
|
||||
|
||||
# Generate the response.
|
||||
result = model.generate(FLAGS.prompt, device, output_len=FLAGS.output_len)
|
||||
|
||||
# Print the prompts and results.
|
||||
print('======================================')
|
||||
print(f'PROMPT: {FLAGS.prompt}')
|
||||
print(f'RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
||||
|
||||
# How to run this script:
|
||||
|
||||
# Example command (replace with your actual paths and values):
|
||||
# python scripts/run.py --device=cpu --ckpt=/path/to/your/pytorch_checkpoint/model.ckpt --output_len=2 --prompt="The name of the capital of Italy is"
|
||||
# Important:
|
||||
# - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file.
|
||||
# - Choose the correct --variant (model size).
|
||||
# - Use --device=cuda if you have a GPU; otherwise, use --device=cpu.
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import contextlib
|
||||
import random
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
from gemma import config
|
||||
from gemma import gemma3_model
|
||||
|
||||
# Define flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_CKPT = flags.DEFINE_string(
|
||||
'ckpt', None, 'Path to the checkpoint file.', required=True
|
||||
)
|
||||
_VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.')
|
||||
_DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
|
||||
_OUTPUT_LEN = flags.DEFINE_integer(
|
||||
'output_len', 10, 'Length of the output sequence.'
|
||||
)
|
||||
_SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.')
|
||||
_QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
|
||||
|
||||
# Define valid multimodal model variants
|
||||
_VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3']
|
||||
|
||||
# Define valid devices
|
||||
_VALID_DEVICES = ['cpu', 'cuda']
|
||||
|
||||
|
||||
# Validator function for the 'variant' flag
|
||||
def validate_variant(variant):
|
||||
if variant not in _VALID_MODEL_VARIANTS:
|
||||
raise ValueError(
|
||||
f'Invalid variant: {variant}. Valid variants are:'
|
||||
f' {_VALID_MODEL_VARIANTS}'
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# Validator function for the 'device' flag
|
||||
def validate_device(device):
|
||||
if device not in _VALID_DEVICES:
|
||||
raise ValueError(
|
||||
f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}'
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# Register the validator for the 'variant' flag
|
||||
flags.register_validator(
|
||||
'variant', validate_variant, message='Invalid model variant.'
|
||||
)
|
||||
|
||||
# Register the validator for the 'device' flag
|
||||
flags.register_validator('device', validate_device, message='Invalid device.')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_tensor_type(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
|
||||
def main(_):
|
||||
# Construct the model config.
|
||||
model_config = config.get_model_config(_VARIANT.value)
|
||||
model_config.dtype = 'float32'
|
||||
model_config.quant = _QUANT.value
|
||||
image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
|
||||
"lilly": "scripts/images/lilly.jpg",
|
||||
"sunflower": "scripts/images/sunflower.JPG",
|
||||
'golden_test_image': (
|
||||
'scripts/images/test_image.jpg'
|
||||
),
|
||||
}
|
||||
|
||||
image = {}
|
||||
for key in image_paths:
|
||||
try:
|
||||
image[key] = Image.open(image_paths[key]) # Open local file
|
||||
image[key].show()
|
||||
except IOError as e:
|
||||
print(f"Error loading image: {e}")
|
||||
exit()
|
||||
|
||||
# Seed random.
|
||||
random.seed(_SEED.value)
|
||||
np.random.seed(_SEED.value)
|
||||
torch.manual_seed(_SEED.value)
|
||||
|
||||
# Create the model and load the weights.
|
||||
device = torch.device(_DEVICE.value)
|
||||
with _set_default_tensor_type(model_config.get_dtype()):
|
||||
model = gemma3_model.Gemma3ForMultimodalLM(model_config)
|
||||
model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
|
||||
# model.load_weights(_CKPT.value)
|
||||
model = model.to(device).eval()
|
||||
print('Model loading done')
|
||||
|
||||
# Generate text only.
|
||||
result = model.generate(
|
||||
[
|
||||
[
|
||||
'<start_of_turn>user The capital of Italy'
|
||||
' is?<end_of_turn>\n<start_of_turn>model'
|
||||
],
|
||||
[
|
||||
'<start_of_turn>user What is your'
|
||||
' purpose?<end_of_turn>\n<start_of_turn>model'
|
||||
],
|
||||
],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the results.
|
||||
print('======================================')
|
||||
print(f'Text only RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
# Generate golden Gemax test image.
|
||||
result = model.generate(
|
||||
[[
|
||||
'<start_of_turn>user\n',
|
||||
image['golden_test_image'],
|
||||
'Caption this image. <end_of_turn>\n<start_of_turn>model',
|
||||
]],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the result.
|
||||
print('======================================')
|
||||
print(f'Golden test image RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
# Generate text and image.
|
||||
result = model.generate(
|
||||
[[
|
||||
'<start_of_turn>user\n',
|
||||
image['cow_in_beach'],
|
||||
(
|
||||
'The name of the animal in the image is'
|
||||
' <end_of_turn>\n<start_of_turn>model'
|
||||
),
|
||||
]],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the result.
|
||||
print('======================================')
|
||||
print(f'Single image RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
# Generate interleave text and multiple images.
|
||||
result = model.generate(
|
||||
[[
|
||||
'<start_of_turn>user\nThis image',
|
||||
image['lilly'],
|
||||
'and this image',
|
||||
image['sunflower'],
|
||||
'are similar because? <end_of_turn>\n<start_of_turn>model',
|
||||
]],
|
||||
device,
|
||||
output_len=_OUTPUT_LEN.value,
|
||||
)
|
||||
|
||||
# Print the result.
|
||||
print('======================================')
|
||||
print(f'Interleave images RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
|
||||
from gemma.config import GemmaConfig, get_model_config
|
||||
from gemma.model_xla import GemmaForCausalLM
|
||||
from gemma.tokenizer import Tokenizer
|
||||
import gemma.xla_model_parallel as xla_model_parallel
|
||||
|
||||
USE_CUDA = os.environ.get('USE_CUDA', False)
|
||||
if not USE_CUDA:
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
else:
|
||||
# Choose an available port.
|
||||
with contextlib.closing(socket.socket(socket.AF_INET,
|
||||
socket.SOCK_STREAM)) as s:
|
||||
s.bind(('', 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
MASTER_PORT = str(s.getsockname()[1])
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_tensor_type(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
|
||||
def generate(
|
||||
i: int,
|
||||
model_config: GemmaConfig,
|
||||
ckpt_path: str,
|
||||
prompts: List[str],
|
||||
output_lens: List[int],
|
||||
temperatures: Union[List[float], None],
|
||||
top_ps: List[float],
|
||||
top_ks: List[int],
|
||||
seed: int
|
||||
):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if USE_CUDA:
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = MASTER_PORT
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
"nccl",
|
||||
rank=int(os.environ.get("RANK", 0)),
|
||||
world_size=int(os.environ.get("WORLD_SIZE", 1)))
|
||||
xla_model_parallel.set_g_group()
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
else:
|
||||
device = xm.xla_device()
|
||||
xm.set_rng_state(seed, device)
|
||||
|
||||
rank = xla_model_parallel.get_model_parallel_rank()
|
||||
world_size = xla_model_parallel.get_model_parallel_world_size()
|
||||
if rank > 0:
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
|
||||
# build, load and compile model.
|
||||
with _set_default_tensor_type(model_config.get_dtype()):
|
||||
model = GemmaForCausalLM(model_config, world_size, rank, device)
|
||||
model.load_weights(ckpt_path)
|
||||
model = model.to(device).eval()
|
||||
|
||||
# create tokenizer.
|
||||
tokenizer = Tokenizer(model_config.tokenizer)
|
||||
|
||||
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
|
||||
min_prompt_len = min(len(p) for p in prompt_tokens)
|
||||
|
||||
batch_size = len(prompts)
|
||||
if temperatures is not None:
|
||||
assert batch_size == len(temperatures)
|
||||
assert batch_size == len(top_ps)
|
||||
assert batch_size == len(top_ks)
|
||||
max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
|
||||
assert max_seq_len <= model_config.max_position_embeddings
|
||||
if model_config.num_key_value_heads < world_size:
|
||||
assert world_size % model_config.num_key_value_heads == 0
|
||||
n_local_heads = 1
|
||||
else:
|
||||
assert model_config.num_key_value_heads % world_size == 0
|
||||
n_local_heads = model_config.num_key_value_heads // world_size
|
||||
|
||||
# build KV caches
|
||||
kv_caches = []
|
||||
for _ in range(model_config.num_hidden_layers):
|
||||
k_cache = torch.zeros(
|
||||
size=(batch_size, max_seq_len, n_local_heads,
|
||||
model_config.head_dim),
|
||||
dtype=model_config.get_dtype(),
|
||||
device=device,
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
size=(batch_size, max_seq_len, n_local_heads,
|
||||
model_config.head_dim),
|
||||
dtype=model_config.get_dtype(),
|
||||
device=device,
|
||||
)
|
||||
kv_caches.append((k_cache, v_cache))
|
||||
|
||||
# prepare inputs
|
||||
token_ids_tensor = torch.full((batch_size, max_seq_len),
|
||||
tokenizer.pad_id,
|
||||
dtype=torch.int64)
|
||||
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
|
||||
tokenizer.pad_id,
|
||||
dtype=torch.int64)
|
||||
for i, p in enumerate(prompt_tokens):
|
||||
token_ids_tensor[i, :len(p)] = torch.tensor(p)
|
||||
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
|
||||
p[:min_prompt_len])
|
||||
token_ids_tensor = token_ids_tensor.to(device)
|
||||
prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
|
||||
input_token_ids_tensor = input_token_ids_tensor.to(device)
|
||||
input_positions_tensor = torch.arange(0, min_prompt_len,
|
||||
dtype=torch.int64).to(device)
|
||||
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
|
||||
-2.3819763e38).to(torch.float)
|
||||
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
|
||||
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
|
||||
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
|
||||
temperatures_tensor = None if not temperatures else torch.FloatTensor(temperatures).to(device)
|
||||
top_ps_tensor = torch.FloatTensor(top_ps).to(device)
|
||||
top_ks_tensor = torch.LongTensor(top_ks).to(device)
|
||||
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
|
||||
if not USE_CUDA:
|
||||
xm.mark_step()
|
||||
|
||||
# Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
|
||||
for i in range(max_seq_len - min_prompt_len):
|
||||
next_token_ids, _ = model(
|
||||
input_token_ids=input_token_ids_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
kv_write_indices=None,
|
||||
kv_caches=kv_caches,
|
||||
mask=curr_mask_tensor,
|
||||
output_positions=output_positions_tensor,
|
||||
temperatures=temperatures_tensor,
|
||||
top_ps=top_ps_tensor,
|
||||
top_ks=top_ks_tensor,
|
||||
)
|
||||
curr_prompt_mask = prompt_mask_tensor.index_select(
|
||||
1, output_index).squeeze(dim=1)
|
||||
curr_token_ids = token_ids_tensor.index_select(
|
||||
1, output_index).squeeze(dim=1)
|
||||
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
|
||||
next_token_ids).unsqueeze(dim=1)
|
||||
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
|
||||
|
||||
input_token_ids_tensor = output_token_ids
|
||||
input_positions_tensor = output_index.unsqueeze(dim=-1)
|
||||
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
|
||||
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
|
||||
output_index = output_index + 1
|
||||
if not USE_CUDA:
|
||||
xm.mark_step()
|
||||
|
||||
# Detokenization.
|
||||
token_ids = token_ids_tensor.tolist()
|
||||
results = []
|
||||
for i, tokens in enumerate(token_ids):
|
||||
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
|
||||
output_lens[i]]
|
||||
if tokenizer.eos_id in trimmed_output:
|
||||
eos_index = trimmed_output.index(tokenizer.eos_id)
|
||||
trimmed_output = trimmed_output[:eos_index]
|
||||
results.append(tokenizer.decode(trimmed_output))
|
||||
|
||||
for prompt, result in zip(prompts, results):
|
||||
print('======================================')
|
||||
print(f'PROMPT: {prompt}')
|
||||
print(f'RESULT: {result}')
|
||||
print('======================================')
|
||||
|
||||
|
||||
def main(args):
|
||||
model_config = get_model_config(args.variant)
|
||||
model_config.quant = args.quant
|
||||
|
||||
prompts = [args.prompt]
|
||||
n = len(prompts)
|
||||
output_lengths = [args.output_len] * n
|
||||
temperatures = [0.95] * n
|
||||
top_ps = [1.0] * n
|
||||
top_ks = [100] * n
|
||||
|
||||
if USE_CUDA:
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = MASTER_PORT
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
"nccl",
|
||||
rank=int(os.environ.get("RANK", 0)),
|
||||
world_size=int(os.environ.get("WORLD_SIZE", 1)))
|
||||
xla_model_parallel.set_g_group()
|
||||
torch.multiprocessing.spawn(
|
||||
generate,
|
||||
args=(
|
||||
model_config,
|
||||
args.ckpt,
|
||||
prompts,
|
||||
output_lengths,
|
||||
temperatures,
|
||||
top_ps,
|
||||
top_ks,
|
||||
args.seed,
|
||||
),
|
||||
)
|
||||
else:
|
||||
xmp.spawn(
|
||||
generate,
|
||||
args=(
|
||||
model_config,
|
||||
args.ckpt,
|
||||
prompts,
|
||||
output_lengths,
|
||||
temperatures,
|
||||
top_ps,
|
||||
top_ks,
|
||||
args.seed,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt", type=str, required=True)
|
||||
parser.add_argument("--variant",
|
||||
type=str,
|
||||
default="2b",
|
||||
choices=["2b", "2b-v2", "7b", "9b", "27b"])
|
||||
parser.add_argument("--output_len", type=int, default=4)
|
||||
parser.add_argument("--seed", type=int, default=12345)
|
||||
parser.add_argument("--quant", action='store_true')
|
||||
parser.add_argument("--prompt", type=str, default="The meaning of life is")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Reference in New Issue
Block a user