{ "cells": [ { "metadata": { "id": "-KkvqLgjiIdD" }, "cell_type": "markdown", "source": [ "# Tool Use\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/tool_use.ipynb)\n", "\n", "This demo shows how to use tools with the Gemma library.\n", "\n", "Note: The Gemma 1, 2 and 3 models were not specifically trained for tool use. This is more of a proof of concept than an officially supported feature." ] }, { "metadata": { "id": "gcNRfVEnj4aq" }, "cell_type": "code", "source": [ "!pip install -q gemma" ], "outputs": [], "execution_count": null }, { "metadata": { "executionInfo": { "elapsed": 2221, "status": "ok", "timestamp": 1749202985345, "user": { "displayName": "", "userId": "" }, "user_tz": -120 }, "id": "k1ZAgLg1j9NT" }, "cell_type": "code", "source": [ "# Common imports\n", "import os\n", "import datetime\n", "\n", "# Gemma imports\n", "from gemma import gm" ], "outputs": [], "execution_count": 3 }, { "metadata": { "id": "139lZszJj_CC" }, "cell_type": "markdown", "source": [ "By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):" ] }, { "metadata": { "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1749138071985, "user": { "displayName": "", "userId": "" }, "user_tz": -120 }, "id": "VtlWWLIYj_LJ" }, "cell_type": "code", "source": [ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"" ], "outputs": [], "execution_count": 2 }, { "metadata": { "id": "31JPZb5RkD_p" }, "cell_type": "markdown", "source": [ "Load the model and its parameters." 
] }, { "metadata": { "executionInfo": { "elapsed": 39057, "status": "ok", "timestamp": 1749203024713, "user": { "displayName": "", "userId": "" }, "user_tz": -120 }, "id": "RsAo6k4_kEJS", "outputId": "e10afb5c-6c81-42e8-e590-a39ea4ef3bf7" }, "cell_type": "code", "source": [ "model = gm.nn.Gemma3_4B()\n", "\n", "params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)" ], "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:2025-06-06 02:43:16,896:jax._src.xla_bridge:749: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'\n", "INFO:2025-06-06 02:43:16,897:jax._src.xla_bridge:749: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '://' (e.g., 'grpc://localhost'), but got \n", "INFO:2025-06-06 02:43:16,900:jax._src.xla_bridge:749: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'\n", "INFO:2025-06-06 02:43:16,901:jax._src.xla_bridge:749: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'\n" ] } ], "execution_count": 4 }, { "metadata": { "id": "p108c5yIlYH7" }, "cell_type": "markdown", "source": [ "## Using existing tools\n", "\n", "If you're familiar with the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial, tool use differs in two ways:\n", "\n", "1. Using the `gm.text.ToolSampler` rather than the `gm.text.ChatSampler`.\n", "2. 
Passing the `tools=` you want to use to the sampler.\n", "\n", "For example:" ] }, { "metadata": { "colab": { "height": 594 }, "executionInfo": { "elapsed": 50615, "status": "ok", "timestamp": 1749138791069, "user": { "displayName": "", "userId": "" }, "user_tz": -120 }, "id": "iRCV5h8BlVX6", "outputId": "b3b5d83d-8a8b-4982-fc8f-d409fb8b38a9" }, "cell_type": "code", "source": [ "sampler = gm.text.ToolSampler(\n", " model=model,\n", " params=params,\n", " tools=[\n", " gm.tools.Calculator(),\n", " gm.tools.FileExplorer(),\n", " ],\n", " print_stream=True,\n", ")\n", "\n", "output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.\n", "Let's start with S0 = 3.\n", "S1 = cos(S0) * 2 = cos(3) * 2\n", "S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2\n", "S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2\n", "S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2\n", "\n", "I will use the calculator to compute these values.\n", "{\"tool_name\": \"calculator\", \"expression\": \"cos(3) * 2\"}\n", "\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[Tool result: -1.9799849932008908]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2\n", "{\"tool_name\": \"calculator\", \"expression\": \"-1.9799849932008908 * 2\"}\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[Tool result: -3.9599699864017817]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2\n", "{\"tool_name\": \"calculator\", \"expression\": \"cos(-3.9599699864017817) * 2\"}\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[Tool result: -1.3668134299076982]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2\n", "{\"tool_name\": \"calculator\", \"expression\": \"cos(-1.3668134299076982) * 2\"}\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[Tool result: 0.4051424976130353]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2\n", "{\"tool_name\": \"calculator\", \"expression\": \"cos(0.4051424976130353) * 2\"}\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[Tool result: 1.8380924822033438]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438" ] } ], "execution_count": 10 }, { "metadata": { "id": "FAI54F-Blkan" }, "cell_type": "markdown", "source": [ "Note: Only the final model answer is returned. You can access the conversation history, including all intermediates tool calls and output through `sampler.turns` property." ] }, { "metadata": { "id": "D0_IIS1Nlfuw" }, "cell_type": "markdown", "source": [ "## Creating your own tool\n", "\n", "To create your own tool, you can inherit from the `gm.tools.Tool` class. You should provide:\n", "\n", "* A description & example, so the model knows how to use your tool\n", "* Implement the `call` method. The `call` function can take arbitrary `**kwargs`, but the name of the args should match the ones defined in `tool_kwargs` and `tool_kwargs_doc`" ] }, { "metadata": { "executionInfo": { "elapsed": 55, "status": "ok", "timestamp": 1749203934196, "user": { "displayName": "", "userId": "" }, "user_tz": -120 }, "id": "XqmQcfdI0oEl" }, "cell_type": "code", "source": [ "class DateTime(gm.tools.Tool):\n", " \"\"\"Tool to access the current date.\"\"\"\n", "\n", " DESCRIPTION = 'Access the current date, time,...'\n", " EXAMPLE = gm.tools.Example(\n", " query='Which day of the week are we today ?',\n", " thought='The `datetime.strptime` uses %a for day of the week',\n", " tool_kwargs={'format': '%a'},\n", " tool_kwargs_doc={'format': ''},\n", " result='Sat',\n", " answer='Today is Saturday.',\n", " )\n", "\n", " def call(self, format: str) -> str:\n", " dt = datetime.datetime.now()\n", " return dt.strftime(format)\n" ], "outputs": [], "execution_count": 7 }, { "metadata": { "id": "sSxYhXPuuXYp" }, "cell_type": "markdown", "source": [ "The tool can then be used in the sampler:" ] }, { "metadata": { "colab": { "height": 
118 }, "executionInfo": { "elapsed": 2156, "status": "ok", "timestamp": 1749204833094, "user": { "displayName": "", "userId": "" }, "user_tz": -120 }, "id": "9S8xB2B-0cbW", "outputId": "fccc0e89-e922-4184-8b77-800041cdd77e" }, "cell_type": "code", "source": [ "sampler = gm.text.ToolSampler(\n", " model=model,\n", " params=params,\n", " tools=[\n", " DateTime(),\n", " ],\n", " print_stream=True,\n", ")\n", "\n", "output = sampler.chat('Which date are we today ?')" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Thought: I need to get the current date.\n", "{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[Tool result: 2025-06-06]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Today is June 6th, 2025." ] } ], "execution_count": 9 }, { "metadata": { "id": "esIpCjhxzHmf" }, "cell_type": "markdown", "source": [ "## Next steps\n", "\n", "* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n", "* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n" ] } ], "metadata": { "colab": { "last_runtime": {}, "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }