{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "cells": [
{ "cell_type": "markdown", "source": [ "This notebook contains quick \"vibe test\" examples that exercise the image, video, text, and audio capabilities of the Gemma-4 models. To get started, install the latest stable release of transformers." ], "metadata": {} },
{ "cell_type": "code", "source": [ "# %pip (unlike !pip) installs into the environment of the running kernel.\n", "%pip install -U transformers" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "We load the model with `AutoModelForMultimodalLM` so that all of its input modalities can be used in the examples below." ], "metadata": {} },
{ "cell_type": "code", "source": [ "import torch\n", "from PIL import Image\n", "\n", "from transformers import AutoModelForMultimodalLM, AutoProcessor\n", "\n", "# Other checkpoints to try by changing `model_id`:\n", "# \"google/gemma-4-E4B-it\", \"google/gemma-4-26B-A4B-it\", \"google/gemma-4-31B-it\"\n", "model_id = \"google/gemma-4-E2B-it\"\n", "model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map=\"auto\")\n", "processor = AutoProcessor.from_pretrained(model_id)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Screenshot to code" ], "metadata": {} },
{ "cell_type": "markdown", "source": [ "We give Gemma-4 a screenshot of a web page and ask it to write HTML code that reproduces it."
], "metadata": {} },
{ "cell_type": "code", "source": [ "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\n", "                \"type\": \"image\",\n", "                \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png\",\n", "            },\n", "            {\"type\": \"text\", \"text\": \"Write HTML code for this page.\"},\n", "        ],\n", "    }\n", "]\n", "\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tokenize=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\",\n", "    add_generation_prompt=True,\n", "    # Turn on the chat template's thinking mode; the large token budget below leaves\n", "    # room for the reasoning plus a full HTML page.\n", "    enable_thinking=True,\n", ").to(model.device)\n", "\n", "output = model.generate(**inputs, max_new_tokens=4000)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "# Keep only the newly generated tokens (output[0] starts with the prompt) before parsing.\n", "input_len = inputs.input_ids.shape[-1]\n", "generated_text_ids = output[0][input_len:]\n", "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n", "result = processor.parse_response(generated_text)\n", "\n", "print(result[\"content\"])" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Video Inference" ], "metadata": {} },
{ "cell_type": "markdown", "source": [ "We test Gemma-4 on video understanding. `load_audio_from_video=True` also passes the video's audio track to the model. If you run this example with one of the larger checkpoints that don't accept audio input, set it to `False`." ], "metadata": {} },
{ "cell_type": "code", "source": [ "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\"type\": \"video\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4\"},\n", "            {\"type\": \"text\", \"text\": \"What is happening in the video? What is the song about?\"},\n", "        ],\n", "    },\n", "]\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tokenize=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\",\n", "    add_generation_prompt=True,\n", "    load_audio_from_video=True,\n", ").to(model.device)\n", "output = model.generate(**inputs, max_new_tokens=200)\n", "input_len = inputs.input_ids.shape[-1]\n", "generated_text_ids = output[0][input_len:]\n", "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n", "result = processor.parse_response(generated_text)\n" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "print(result[\"content\"])" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Multimodal Function Calling\n", "\n", "The model is shown an image and given the schema of a `get_weather` tool, then asked to identify the city and check its weather. This example stops at the model's response: the tool itself is not implemented or executed here." ], "metadata": {} },
{ "cell_type": "code", "source": [ "WEATHER_TOOL = {\n", "    \"type\": \"function\",\n", "    \"function\": {\n", "        \"name\": \"get_weather\",\n", "        \"description\": \"Gets the current weather for a specific location.\",\n", "        \"parameters\": {\n", "            \"type\": \"object\",\n", "            \"properties\": {\n", "                \"city\": {\"type\": \"string\", \"description\": \"The city name\"},\n", "            },\n", "            \"required\": [\"city\"],\n", "        },\n", "    },\n", "}\n", "tools = [WEATHER_TOOL]\n", "\n", "messages = [\n", "    {\"role\": \"user\", \"content\": [\n", "        {\"type\": \"image\", \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg\"},\n", "        {\"type\": \"text\", \"text\": \"What is the city in this image? Check the weather there right now.\"},\n", "    ]},\n", "]\n", "\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tools=tools,\n", "    tokenize=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\",\n", "    add_generation_prompt=True,\n", "    enable_thinking=True,\n", ").to(model.device)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "output = model.generate(**inputs, max_new_tokens=1000)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "input_len = inputs.input_ids.shape[-1]\n", "generated_text_ids = output[0][input_len:]\n", "# NOTE(review): if the tool-call markers are special tokens, skip_special_tokens=True could\n", "# strip them before parse_response sees them - confirm against the Gemma-4 tokenizer.\n", "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n", "result = processor.parse_response(generated_text)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "# A tool-calling turn may carry little or no plain-text content, so display the\n", "# whole parsed response rather than only its \"content\" field.\n", "result" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Any-to-any inference" ], "metadata": {} },
{ "cell_type": "markdown", "source": [ "We can also run the model through the `any-to-any` pipeline. The cell after the pipeline call runs the same video prompt directly with the processor and `model.generate`, for comparison."
], "metadata": {} },
{ "cell_type": "code", "source": [ "from transformers import pipeline\n", "\n", "# Reuse the model and processor loaded above instead of loading a second copy of the weights.\n", "pipe = pipeline(\"any-to-any\", model=model, processor=processor)\n", "\n", "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\n", "                \"type\": \"video\",\n", "                # Video content goes under the \"url\" key (as in the concert example), not \"image\".\n", "                \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n", "            },\n", "            {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n", "        ],\n", "    }\n", "]\n" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "pipe(messages)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\n", "                \"type\": \"video\",\n", "                \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n", "            },\n", "            {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n", "        ],\n", "    }\n", "]\n", "\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tokenize=True,\n", "    add_generation_prompt=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\",\n", ").to(model.device)\n", "\n", "generated_ids = model.generate(**inputs, max_new_tokens=128)\n", "# Strip the prompt tokens from each sequence so only the model's answer is decoded.\n", "generated_ids_trimmed = [\n", "    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n", "]\n", "output_text = processor.batch_decode(\n", "    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n", ")\n", "print(output_text[0])\n" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Object detection\n", "\n", "We ask the model for the bounding box of an object, parse the JSON it returns, and draw the box. The drawing helper reads `box_2d` as `[ymin, xmin, ymax, xmax]` on a 1000 x 1000 grid and rescales it to the image size." ], "metadata": {} },
{ "cell_type": "code", "source": [ "import json\n", "import re\n", "\n", "import matplotlib.patches as patches\n", "import matplotlib.pyplot as plt\n", "from IPython.display import display\n", "from transformers.image_utils import load_image" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "image_url = \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\"\n", "image = load_image(image_url)\n", "image" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "def resize_to_48_multiple(image):\n", "    \"\"\"Crop `image` from the top-left so its width and height are multiples of 48 pixels.\n", "\n", "    Despite the name, nothing is rescaled: up to 47 pixels are dropped from the right\n", "    and bottom edges. NOTE(review): presumably 48 matches the vision encoder's patch\n", "    grid - confirm against the Gemma-4 image processor.\n", "\n", "    Raises ValueError if either side is shorter than 48 pixels (the crop would be empty).\n", "    \"\"\"\n", "    width, height = image.size\n", "    new_width = (width // 48) * 48\n", "    new_height = (height // 48) * 48\n", "    if new_width == 0 or new_height == 0:\n", "        raise ValueError(f\"Image of size {image.size} is smaller than 48x48 and cannot be cropped.\")\n", "    return image.crop((0, 0, new_width, new_height))" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "def inputs_for_object_detection(image, what_object):\n", "    \"\"\"Build model inputs asking for the bounding box of `what_object` in `image`.\n", "\n", "    Thinking is disabled to keep the reply short and parseable by `extract_json`.\n", "    Returns the processor's batch moved to `model.device`.\n", "    \"\"\"\n", "    messages = [\n", "        {\n", "            \"role\": \"user\",\n", "            \"content\": [\n", "                {\"type\": \"image\", \"image\": image},\n", "                {\"type\": \"text\", \"text\": f\"What's the bounding box for the {what_object} in the image?\"},\n", "            ],\n", "        }\n", "    ]\n", "\n", "    inputs = processor.apply_chat_template(\n", "        messages,\n", "        tokenize=True,\n", "        add_generation_prompt=True,\n", "        return_dict=True,\n", "        return_tensors=\"pt\",\n", "        enable_thinking=False,\n", "    )\n", "\n", "    return inputs.to(model.device)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "def extract_json(text: str):\n", "    \"\"\"Parse the JSON payload out of a model reply.\n", "\n", "    Strips an optional Markdown code fence (```json ... ```) and tries to parse the\n", "    remaining text. If that fails, falls back to the first `{...}` / `[...]` span\n", "    (greedy, so it extends to the last closing bracket of that kind).\n", "\n", "    Raises ValueError (json.JSONDecodeError is a subclass) if no JSON can be parsed.\n", "    \"\"\"\n", "    text = text.strip()\n", "\n", "    text = re.sub(r\"^```(?:json)?\\s*\", \"\", text)\n", "    text = re.sub(r\"\\s*```$\", \"\", text)\n", "\n", "    # Try direct parse first\n", "    try:\n", "        return json.loads(text)\n", "    except json.JSONDecodeError:\n", "        pass\n", "\n", "    # Fallback: extract first JSON object or array embedded in surrounding prose\n", "    match = re.search(r'(\\{.*\\}|\\[.*\\])', text, re.DOTALL)\n", "    if match:\n", "        return json.loads(match.group(1))\n", "\n", "    raise ValueError(\"No valid JSON found\")" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "def detect_object(image_url, what_object):\n", "    \"\"\"Ask the model for the bounding box of `what_object` and return the first detection.\n", "\n", "    `image_url` may be a URL/path or an already-loaded PIL image (anything `load_image`\n", "    accepts). The image is cropped with `resize_to_48_multiple` before inference, so any\n", "    coordinates in the reply refer to that crop. Returns the parsed detection (callers\n", "    read its \"box_2d\" and optional \"label\" keys).\n", "\n", "    Raises ValueError if the reply contains no parsable JSON or an empty detection list.\n", "    \"\"\"\n", "    image = resize_to_48_multiple(load_image(image_url))\n", "    inputs = inputs_for_object_detection(image, what_object)\n", "    input_len = inputs[\"input_ids\"].shape[-1]\n", "    generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)\n", "    # Drop special tokens (e.g. end-of-turn markers) so the trailing code fence can be stripped.\n", "    generated = processor.decode(generated_outputs[0, input_len:], skip_special_tokens=True)\n", "    parsed = extract_json(generated)\n", "    # Expect a list of detections, but also accept a single bare object.\n", "    if isinstance(parsed, list):\n", "        if not parsed:\n", "            raise ValueError(f\"No detections for {what_object!r} in model reply: {generated!r}\")\n", "        parsed = parsed[0]\n", "    return parsed" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000, 1000)):\n", "    \"\"\"Draw one bounding box (and optional label) on `image`, save it as boxes_{i}.png and display it.\n", "\n", "    Despite the name, `box` is read as [ymin, xmin, ymax, xmax] on a grid of size\n", "    `resize_shape` (height, width), which is rescaled to the image's pixel size.\n", "    Pass resize_shape=None when `box` is already in pixel coordinates.\n", "    \"\"\"\n", "    dpi = 72\n", "    width, height = image.size\n", "    fig, ax = plt.subplots(1, figsize=[width / dpi, height / dpi], tight_layout={'pad': 0})\n", "\n", "    ax.imshow(image)\n", "\n", "    ymin, xmin, ymax, xmax = box\n", "    grid_h, grid_w = resize_shape if resize_shape is not None else (height, width)\n", "    xmin = (xmin / grid_w) * width\n", "    ymin = (ymin / grid_h) * height\n", "    xmax = (xmax / grid_w) * width\n", "    ymax = (ymax / grid_h) * height\n", "\n", "    rect = patches.Rectangle(\n", "        (xmin, ymin),\n", "        xmax - xmin,\n", "        ymax - ymin,\n", "        linewidth=10,\n", "        edgecolor=\"green\",\n", "        facecolor=\"none\",\n", "    )\n", "    ax.add_patch(rect)\n", "\n", "    if label is not None:\n", "        ax.text(xmin, ymin - 25, label, fontsize=24, bbox=dict(facecolor=\"yellow\", alpha=0.5))\n", "\n", "    ax.axis(\"off\")\n", "    fig.savefig(f\"boxes_{i}.png\")\n", "    # Presumably closed before display() so the inline backend does not render the figure twice.\n", "    plt.close(fig)\n", "    display(fig)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "def display_detected_object(image_url, what_object):\n", "    \"\"\"Detect `what_object` in the image at `image_url` and draw its bounding box.\"\"\"\n", "    image = resize_to_48_multiple(load_image(image_url))\n", "    # Pass the already-loaded image so it is not downloaded twice; re-cropping an image whose\n", "    # sides are already multiples of 48 is a no-op, so the box still matches `image`.\n", "    detection = detect_object(image, what_object)\n", "    box = detection[\"box_2d\"]\n", "    label = detection.get(\"label\", what_object)\n", "    draw_pascal_voc_boxes(\"1000\", image, box, label)" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"display_detected_object(\"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\", \"bike\")" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Captioning" ], "metadata": {} },
{ "cell_type": "code", "source": [ "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png\"},\n", "            {\"type\": \"text\", \"text\": \"Write single detailed caption for this image.\"},\n", "        ],\n", "    },\n", "]\n", "\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tokenize=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\",\n", "    add_generation_prompt=True,\n", ").to(model.device)\n", "\n", "output = model.generate(**inputs, max_new_tokens=512)\n", "input_len = inputs.input_ids.shape[-1]\n", "generated_text_ids = output[0][input_len:]\n", "generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n", "result = processor.parse_response(generated_text)\n", "print(result[\"content\"])" ], "metadata": {}, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Audio Understanding" ], "metadata": {} },
{ "cell_type": "code", "source": [ "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\"type\": \"audio\", \"url\": \"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3\"},\n", "            {\"type\": \"text\", \"text\": \"Can you describe this audio in detail?\"},\n", "        ],\n", "    },\n", "]\n", "\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tokenize=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\",\n", "    add_generation_prompt=True,\n", ").to(model.device)\n", "\n", "output = model.generate(\n", "    **inputs,\n", "    max_new_tokens=1000,\n", "    do_sample=False,\n", ")\n", "\n", "# Decode only the newly generated tokens (output[0] starts with the prompt), as in the other examples.\n", "input_len = inputs.input_ids.shape[-1]\n", "generated_text = processor.decode(output[0][input_len:], skip_special_tokens=True)\n", "result = processor.parse_response(generated_text)\n", "print(result[\"content\"])\n" ], "metadata": {}, "execution_count": null, "outputs": [] } ] }