docs: add canonical tooling corpus (147 files) from Google/HF/frameworks

Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mortdecai
2026-04-18 12:24:48 -04:00
parent 5011059f5d
commit eecebe7ef5
149 changed files with 181297 additions and 0 deletions
@@ -0,0 +1,186 @@
# Gemma in PyTorch
**Gemma** is a family of lightweight, state-of-the art open models built from research and technology used to create Google Gemini models. They include both text-only and multimodal decoder-only large language models, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:
* [Gemma on Google AI](https://ai.google.dev/gemma)
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma-3)
* [Gemma on Vertex AI Model Garden](https://pantheon.corp.google.com/vertex-ai/publishers/google/model-garden/gemma3)
This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.
## Updates
* [March 12th, 2025 🔥] Support Gemma v3. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-3/pytorch) and [Hugging Face](https://huggingface.co/models?other=gemma_torch)
* [June 26th, 2024] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
* [April 9th, 2024] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
* [April 5, 2024] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
## Download Gemma model checkpoint
You can find the model checkpoints on Kaggle:
- [Gemma 3](https://www.kaggle.com/models/google/gemma-3/pyTorch)
- [Gemma 2](https://www.kaggle.com/models/google/gemma-2/pyTorch)
- [Gemma](https://www.kaggle.com/models/google/gemma/pyTorch)
Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go the the model repository of the model of interest and click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub` installed, you can also run:
```
huggingface-cli download google/gemma-3-4b-it-pytorch
```
The following model sizes are available:
- **Gemma 3**:
- **Text only**: 1b
- **Multimodal**: 4b, 12b, 27b_v3
- **Gemma 2**:
- **Text only**: 2b-v2, 9b, 27b
- **Gemma**:
- **Text only**: 2b, 7b
Note that you can choose between the 1B, 4B, 12B, and 27B variants.
```
VARIANT=<1b, 2b, 2b-v2, 4b, 7b, 9b, 12b, 27b, 27b_v3>
CKPT_PATH=<Insert ckpt path here>
```
## Try it free on Colab
Follow the steps at
[https://ai.google.dev/gemma/docs/pytorch_gemma](https://ai.google.dev/gemma/docs/pytorch_gemma).
## Try it out with PyTorch
Prerequisite: make sure you have setup docker permission properly as a non-root user.
```bash
sudo usermod -aG docker $USER
newgrp docker
```
### Build the docker image.
```bash
DOCKER_URI=gemma:${USER}
docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}
```
### Run Gemma inference on CPU.
> NOTE: This is a multimodal example. Use a multimodal variant.
```bash
docker run -t --rm \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_multimodal.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Run Gemma inference on GPU.
> NOTE: This is a multimodal example. Use a multimodal variant.
```bash
docker run -t --rm \
--gpus all \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_multimodal.py \
--device=cuda \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}"
# add `--quant` for the int8 quantized model.
```
## Try It out with PyTorch/XLA
### Build the docker image (CPU, TPU).
```bash
DOCKER_URI=gemma_xla:${USER}
docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
```
### Build the docker image (GPU).
```bash
DOCKER_URI=gemma_xla_gpu:${USER}
docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}
```
### Run Gemma inference on CPU.
> NOTE: This is a multimodal example. Use a multimodal variant.
```bash
docker run -t --rm \
--shm-size 4gb \
-e PJRT_DEVICE=CPU \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_xla.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Run Gemma inference on TPU.
Note: be sure to use the docker container built from `xla.Dockerfile`.
```bash
docker run -t --rm \
--shm-size 4gb \
-e PJRT_DEVICE=TPU \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_xla.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Run Gemma inference on GPU.
Note: be sure to use the docker container built from `xla_gpu.Dockerfile`.
```bash
docker run -t --rm --privileged \
--shm-size=16g --net=host --gpus all \
-e USE_CUDA=1 \
-e PJRT_DEVICE=CUDA \
-v ${CKPT_PATH}:/tmp/ckpt \
${DOCKER_URI} \
python scripts/run_xla.py \
--ckpt=/tmp/ckpt \
--variant="${VARIANT}" \
# add `--quant` for the int8 quantized model.
```
### Tokenizer Notes
99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. Unused tokens are in the string format of `<unused[0-97]>` with token id range of `[7-104]`.
```
"<unused0>": 7,
"<unused1>": 8,
"<unused2>": 9,
...
"<unused98>": 104,
```
## Disclaimer
This is not an officially supported Google product.
@@ -0,0 +1,107 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import random
import numpy as np
import torch
from absl import app, flags
from gemma import config
from gemma import model as gemma_model
# Define flags
FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True)
flags.DEFINE_string('variant', '4b', 'Model variant.')
flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.')
flags.DEFINE_integer('seed', 12345, 'Random seed.')
flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
flags.DEFINE_string('prompt', 'What are large language models?', 'Input prompt for the model.')
# Define valid text only model variants
_VALID_MODEL_VARIANTS = ['2b', '2b-v2', '7b', '9b', '27b', '1b']
# Define valid devices
_VALID_DEVICES = ['cpu', 'cuda']
# Validator function for the 'variant' flag
def validate_variant(variant):
if variant not in _VALID_MODEL_VARIANTS:
raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}')
return True
# Validator function for the 'device' flag
def validate_device(device):
if device not in _VALID_DEVICES:
raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}')
return True
# Register the validator for the 'variant' flag
flags.register_validator('variant', validate_variant, message='Invalid model variant.')
# Register the validator for the 'device' flag
flags.register_validator('device', validate_device, message='Invalid device.')
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def main(_):
# Construct the model config.
model_config = config.get_model_config(FLAGS.variant)
model_config.dtype = "float32"
model_config.quant = FLAGS.quant
# Seed random.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
torch.manual_seed(FLAGS.seed)
# Create the model and load the weights.
device = torch.device(FLAGS.device)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(FLAGS.ckpt)
model = model.to(device).eval()
print("Model loading done")
# Generate the response.
result = model.generate(FLAGS.prompt, device, output_len=FLAGS.output_len)
# Print the prompts and results.
print('======================================')
print(f'PROMPT: {FLAGS.prompt}')
print(f'RESULT: {result}')
print('======================================')
if __name__ == "__main__":
app.run(main)
# How to run this script:
# Example command (replace with your actual paths and values):
# python scripts/run.py --device=cpu --ckpt=/path/to/your/pytorch_checkpoint/model.ckpt --output_len=2 --prompt="The name of the capital of Italy is"
# Important:
# - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file.
# - Choose the correct --variant (model size).
# - Use --device=cuda if you have a GPU; otherwise, use --device=cpu.
@@ -0,0 +1,197 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import random
from absl import app
from absl import flags
import numpy as np
from PIL import Image
import torch
from gemma import config
from gemma import gemma3_model
# Define flags
FLAGS = flags.FLAGS
_CKPT = flags.DEFINE_string(
'ckpt', None, 'Path to the checkpoint file.', required=True
)
_VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.')
_DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
_OUTPUT_LEN = flags.DEFINE_integer(
'output_len', 10, 'Length of the output sequence.'
)
_SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.')
_QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
# Define valid multimodal model variants
_VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3']
# Define valid devices
_VALID_DEVICES = ['cpu', 'cuda']
# Validator function for the 'variant' flag
def validate_variant(variant):
if variant not in _VALID_MODEL_VARIANTS:
raise ValueError(
f'Invalid variant: {variant}. Valid variants are:'
f' {_VALID_MODEL_VARIANTS}'
)
return True
# Validator function for the 'device' flag
def validate_device(device):
if device not in _VALID_DEVICES:
raise ValueError(
f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}'
)
return True
# Register the validator for the 'variant' flag
flags.register_validator(
'variant', validate_variant, message='Invalid model variant.'
)
# Register the validator for the 'device' flag
flags.register_validator('device', validate_device, message='Invalid device.')
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def main(_):
# Construct the model config.
model_config = config.get_model_config(_VARIANT.value)
model_config.dtype = 'float32'
model_config.quant = _QUANT.value
image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
"lilly": "scripts/images/lilly.jpg",
"sunflower": "scripts/images/sunflower.JPG",
'golden_test_image': (
'scripts/images/test_image.jpg'
),
}
image = {}
for key in image_paths:
try:
image[key] = Image.open(image_paths[key]) # Open local file
image[key].show()
except IOError as e:
print(f"Error loading image: {e}")
exit()
# Seed random.
random.seed(_SEED.value)
np.random.seed(_SEED.value)
torch.manual_seed(_SEED.value)
# Create the model and load the weights.
device = torch.device(_DEVICE.value)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma3_model.Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
# model.load_weights(_CKPT.value)
model = model.to(device).eval()
print('Model loading done')
# Generate text only.
result = model.generate(
[
[
'<start_of_turn>user The capital of Italy'
' is?<end_of_turn>\n<start_of_turn>model'
],
[
'<start_of_turn>user What is your'
' purpose?<end_of_turn>\n<start_of_turn>model'
],
],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the results.
print('======================================')
print(f'Text only RESULT: {result}')
print('======================================')
# Generate golden Gemax test image.
result = model.generate(
[[
'<start_of_turn>user\n',
image['golden_test_image'],
'Caption this image. <end_of_turn>\n<start_of_turn>model',
]],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the result.
print('======================================')
print(f'Golden test image RESULT: {result}')
print('======================================')
# Generate text and image.
result = model.generate(
[[
'<start_of_turn>user\n',
image['cow_in_beach'],
(
'The name of the animal in the image is'
' <end_of_turn>\n<start_of_turn>model'
),
]],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the result.
print('======================================')
print(f'Single image RESULT: {result}')
print('======================================')
# Generate interleave text and multiple images.
result = model.generate(
[[
'<start_of_turn>user\nThis image',
image['lilly'],
'and this image',
image['sunflower'],
'are similar because? <end_of_turn>\n<start_of_turn>model',
]],
device,
output_len=_OUTPUT_LEN.value,
)
# Print the result.
print('======================================')
print(f'Interleave images RESULT: {result}')
print('======================================')
if __name__ == '__main__':
app.run(main)
@@ -0,0 +1,267 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import random
import socket
import sys
from typing import List, Union
import numpy as np
import torch
import torch.multiprocessing
from gemma.config import GemmaConfig, get_model_config
from gemma.model_xla import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import gemma.xla_model_parallel as xla_model_parallel
USE_CUDA = os.environ.get('USE_CUDA', False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
else:
# Choose an available port.
with contextlib.closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
MASTER_PORT = str(s.getsockname()[1])
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def generate(
i: int,
model_config: GemmaConfig,
ckpt_path: str,
prompts: List[str],
output_lens: List[int],
temperatures: Union[List[float], None],
top_ps: List[float],
top_ks: List[int],
seed: int
):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device("cuda", local_rank)
torch.cuda.set_device(local_rank)
else:
device = xm.xla_device()
xm.set_rng_state(seed, device)
rank = xla_model_parallel.get_model_parallel_rank()
world_size = xla_model_parallel.get_model_parallel_world_size()
if rank > 0:
sys.stdout = open(os.devnull, 'w')
# build, load and compile model.
with _set_default_tensor_type(model_config.get_dtype()):
model = GemmaForCausalLM(model_config, world_size, rank, device)
model.load_weights(ckpt_path)
model = model.to(device).eval()
# create tokenizer.
tokenizer = Tokenizer(model_config.tokenizer)
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
batch_size = len(prompts)
if temperatures is not None:
assert batch_size == len(temperatures)
assert batch_size == len(top_ps)
assert batch_size == len(top_ks)
max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
assert max_seq_len <= model_config.max_position_embeddings
if model_config.num_key_value_heads < world_size:
assert world_size % model_config.num_key_value_heads == 0
n_local_heads = 1
else:
assert model_config.num_key_value_heads % world_size == 0
n_local_heads = model_config.num_key_value_heads // world_size
# build KV caches
kv_caches = []
for _ in range(model_config.num_hidden_layers):
k_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
v_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
kv_caches.append((k_cache, v_cache))
# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
tokenizer.pad_id,
dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
input_token_ids_tensor = input_token_ids_tensor.to(device)
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
temperatures_tensor = None if not temperatures else torch.FloatTensor(temperatures).to(device)
top_ps_tensor = torch.FloatTensor(top_ps).to(device)
top_ks_tensor = torch.LongTensor(top_ks).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
if not USE_CUDA:
xm.mark_step()
# Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
next_token_ids, _ = model(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
output_index = output_index + 1
if not USE_CUDA:
xm.mark_step()
# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
output_lens[i]]
if tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(tokenizer.decode(trimmed_output))
for prompt, result in zip(prompts, results):
print('======================================')
print(f'PROMPT: {prompt}')
print(f'RESULT: {result}')
print('======================================')
def main(args):
model_config = get_model_config(args.variant)
model_config.quant = args.quant
prompts = [args.prompt]
n = len(prompts)
output_lengths = [args.output_len] * n
temperatures = [0.95] * n
top_ps = [1.0] * n
top_ks = [100] * n
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
torch.multiprocessing.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
else:
xmp.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "2b-v2", "7b", "9b", "27b"])
parser.add_argument("--output_len", type=int, default=4)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()
main(args)