{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "G3MMAcssHTML" }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2024 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "4qxv4Sn9b8CE" }, "source": [ "
\n",
" View on ai.google.dev\n",
" | \n",
" \n",
" Run in Google Colab\n",
" | \n",
" \n",
" Run in Kaggle\n",
" | \n",
" \n",
" | \n",
" \n",
" View source on GitHub\n",
" | \n",
"
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
"\n"
],
"text/plain": [
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Tokenizer (type) ┃ Vocab # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n",
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n",
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Model: \"gemma_causal_lm\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ gemma_backbone │ (None, None, 2304) │ 2,614,341,888 │ padding_mask[0][0], │\n",
"│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_embedding │ (None, None, 256000) │ 589,824,000 │ gemma_backbone[0][0] │\n",
"│ (ReversibleEmbedding) │ │ │ │\n",
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n",
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 2,614,341,888 (9.74 GB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 2,614,341,888 (9.74 GB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gemma_lm.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "ArZPOzFpVp6S" }, "source": [ "As you can see from the summary, the model has 2.6 billion trainable parameters.\n", "\n", "Note: For purposes of naming the model (\"2B\"), the embedding layer is not counted against the number of parameters." ] }, { "cell_type": "markdown", "metadata": { "id": "1WpS39TBYql9" }, "source": [ "### Define formatting helper functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3-obTC1jZGpZ" }, "outputs": [], "source": [ "from IPython.display import Markdown\n", "import textwrap\n", "\n", "def display_chat(prompt, text):\n", " formatted_prompt = \"🙋♂️
\" + prompt + \"\"\n", " text = text.replace('•', ' *')\n", " text = textwrap.indent(text, '> ', predicate=lambda _: True)\n", " formatted_text = \"🤖\\n\\n\" + text + \"\\n\"\n", " return Markdown(formatted_prompt+formatted_text)\n", "\n", "def to_markdown(text):\n", " text = text.replace('•', ' *')\n", " return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))" ] }, { "cell_type": "markdown", "metadata": { "id": "5ca54e8c" }, "source": [ "## Building the chatbot\n", "\n", "The Gemma instruction-tuned model `gemma2_instruct_2b_en` is fine-tuned to understand the following turn tokens:\n", "\n", "```\n", "
Tell me, in a few words, how to compute all prime numbers up to 1000?🤖\n", "\n", "> **Sieve of Eratosthenes.** \n", ">
Now in Python! No numpy, please!🤖\n", "\n", "> ```python\n", "> def sieve_of_eratosthenes(n):\n", "> \"\"\"Returns a list of prime numbers up to n.\"\"\"\n", "> primes = [True] * (n + 1)\n", "> primes[0] = primes[1] = False\n", "> for i in range(2, int(n**0.5) + 1):\n", "> if primes[i]:\n", "> for j in range(i * i, n + 1, i):\n", "> primes[j] = False\n", "> return [i for i, is_prime in enumerate(primes) if is_prime]\n", "> \n", "> primes = sieve_of_eratosthenes(1000)\n", "> print(primes)\n", "> ```\n", "> \n", "> **Explanation:**\n", "> \n", "> 1. **Initialization:**\n", "> - `primes = [True] * (n + 1)`: Creates a list `primes` of boolean values, initially assuming all numbers are prime.\n", "> - `primes[0] = primes[1] = False`: Sets 0 and 1 as non-prime.\n", "> \n", "> 2. **Iteration:**\n", "> - `for i in range(2, int(n**0.5) + 1):`: Iterates from 2 to the square root of `n`. We only need to check up to the square root because any composite number must have a prime factor less than or equal to its square root.\n", "> - `if primes[i]:`: If `i` is marked as prime:\n", "> - `for j in range(i * i, n + 1, i):`: Marks all multiples of `i` as non-prime.\n", "> \n", "> 3. **Result:**\n", "> - `return [i for i, is_prime in enumerate(primes) if is_prime]`: Creates a list of indices where `primes[i]` is True, representing the prime numbers.\n", "> \n", "> \n", "> Let me know if you'd like a more detailed explanation of any part! \n", ">
Thank you, it works! Can you explain the code in French?🤖\n", "\n", "> Bien sûr ! Voici une explication du code en français :\n", "> \n", "> ```python\n", "> def sieve_of_eratosthenes(n):\n", "> \"\"\"Retourne une liste de nombres premiers jusqu'à n.\"\"\"\n", "> primes = [True] * (n + 1)\n", "> primes[0] = primes[1] = False\n", "> for i in range(2, int(n**0.5) + 1):\n", "> if primes[i]:\n", "> for j in range(i * i, n + 1, i):\n", "> primes[j] = False\n", "> return [i for i, is_prime in enumerate(primes) if is_prime]\n", "> \n", "> primes = sieve_of_eratosthenes(1000)\n", "> print(primes)\n", "> ```\n", "> \n", "> **Explication:**\n", "> \n", "> 1. **Initialisation:**\n", "> - `primes = [True] * (n + 1)`: Crée une liste `primes` de valeurs booléennes, initialement supposant que tous les nombres sont premiers.\n", "> - `primes[0] = primes[1] = False`: Définit 0 et 1 comme non-premiers.\n", "> \n", "> 2. **Itération:**\n", "> - `for i in range(2, int(n**0.5) + 1):`: Itère de 2 jusqu'à la racine carrée de `n`. Nous ne devons vérifier que jusqu'à la racine carrée car tout nombre composite doit avoir un facteur premier inférieur ou égal à sa racine carrée.\n", "> - `if primes[i]:`: Si `i` est considéré comme premier:\n", "> - `for j in range(i * i, n + 1, i):`: Marquer tous les multiples de `i` comme non-premiers.\n", "> \n", "> 3. **Resultat:**\n", "> - `return [i for i, is_prime in enumerate(primes) if is_prime]`: Crée une liste des indices où `primes[i]` est vrai, représentant les nombres premiers.\n", "> \n", "> \n", "> N'hésitez pas à me demander si vous avez besoin d'une explication plus détaillée de quelque chose! \n", ">
Great! Now add those explanations as comments in the code.🤖\n", "\n", "> ```python\n", "> def sieve_of_eratosthenes(n):\n", "> \"\"\"Retourne une liste de nombres premiers jusqu'à n.\"\"\"\n", "> # Initialise une liste de boolean avec True pour tous les nombres de 0 à n\n", "> primes = [True] * (n + 1)\n", "> # Définit 0 et 1 comme non-premiers\n", "> primes[0] = primes[1] = False\n", "> # Itère de 2 à la racine carrée de n\n", "> for i in range(2, int(n**0.5) + 1):\n", "> # Si i est considéré comme premier\n", "> if primes[i]:\n", "> # Itère sur tous les multiples de i\n", "> for j in range(i * i, n + 1, i):\n", "> # Définit les multiples de i comme non-premiers\n", "> primes[j] = False\n", "> # Retourne la liste des indices des nombres premiers\n", "> return [i for i, is_prime in enumerate(primes) if is_prime]\n", "> \n", "> primes = sieve_of_eratosthenes(1000)\n", "> print(primes)\n", "> ```\n", "> \n", "> **Explication:**\n", "> \n", "> * **Initialisation:**\n", "> * `primes = [True] * (n + 1)`: Crée une liste `primes` de valeurs booléennes, initialement supposant que tous les nombres sont premiers.\n", "> * `primes[0] = primes[1] = False`: Définit 0 et 1 comme non-premiers.\n", "> * **Itération:**\n", "> * `for i in range(2, int(n**0.5) + 1):`: Itère de 2 jusqu'à la racine carrée de `n`. Nous ne devons vérifier que jusqu'à la racine carrée car tout nombre composite doit avoir un facteur premier inférieur ou égal à sa racine carrée.\n", "> * `if primes[i]:`: Si `i` est considéré comme premier:\n", "> * `for j in range(i * i, n + 1, i):`: Marquer tous les multiples de `i` comme non-premiers.\n", "> * **Resultat:**\n", "> * `return [i for i, is_prime in enumerate(primes) if is_prime]`: Crée une liste des indices où `primes[i]` est vrai, représentant les nombres premiers. \n", "> \n", "> \n", "> \n", ">