# Gemma4_(E2B)-Multimodal.ipynb — extracted cells
# Source: https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb

# ===== CELL 0 (markdown) =====
# This notebook contains vibe-test examples that exercise the image, text and
# audio capabilities of the Gemma-4 model. To get started, install the latest
# stable release of transformers.

# ===== CELL 1 (code) =====
# Notebook shell command (not valid Python outside a notebook), so it is kept
# as a comment:
# !pip install -U transformers

# ===== CELL 2 (markdown) =====
# We can load the model into `AutoModelForMultimodalLM` to make use of all of
# its capabilities.

# ===== CELL 3 (code) =====
import torch
from PIL import Image
from transformers import AutoModelForMultimodalLM, AutoProcessor

# Other checkpoints listed by the notebook:
#   "google/gemma-4-26B-A4B-it", "google/gemma-4-E4B-it",
#   "google/gemma-4-E2B-it", "google/gemma-4-31B-it"
model_id = "google/gemma-4-E2B-it"

# The processor and the model are both resolved from the same checkpoint id.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto")

# ===== CELL 4 (markdown) =====
# ## Code completion

# ===== CELL 5 (markdown) =====
# We give Gemma-4 a website screenshot and ask it to reproduce the code.

# ===== CELL 6 (code) =====
screenshot_part = {
    "type": "image",
    "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
}
instruction_part = {"type": "text", "text": "Write HTML code for this page."}
messages = [{"role": "user", "content": [screenshot_part, instruction_part]}]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    enable_thinking=True,
)
inputs = inputs.to(model.device)

output = model.generate(**inputs, max_new_tokens=4000)

# ===== CELL 7 (code) =====
# `generate` returns prompt + completion; slice the prompt tokens off so only
# the newly generated ids are decoded.
prompt_len = inputs["input_ids"].shape[-1]
new_token_ids = output[0][prompt_len:]
generated_text = processor.decode(new_token_ids, skip_special_tokens=True)

result = processor.parse_response(generated_text)
print(result["content"])

# ===== CELL 8 (markdown) =====
# ## Video Inference

# ===== CELL 9 (markdown) =====
# We test Gemma-4 on video understanding.
# If you want to run this example with larger models that don't take audio
# input, disable `load_audio_from_video`.

# ===== CELL 10 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
            {"type": "text", "text": "What is happening in the video? What is the song about?"},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    load_audio_from_video=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=200)

# Drop the prompt tokens so only the completion is decoded.
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)

# ===== CELL 11 (code) =====
print(result["content"])

# ===== CELL 12 (markdown) =====
# ## Multimodal Function Calling

# ===== CELL 13 (code) =====
import re

# JSON-schema style description of the single tool offered to the model.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Gets the current weather for a specific location.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "The city name"},
            },
            "required": ["city"],
        },
    },
}
tools = [WEATHER_TOOL]

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
            {"type": "text", "text": "What is the city in this image? Check the weather there right now."},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tools=tools,  # reuse the list defined above instead of rebuilding it
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    enable_thinking=True,
).to(model.device)

# ===== CELL 14 (code) =====
output = model.generate(**inputs, max_new_tokens=1000)

# ===== CELL 15 (code) =====
input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)

# ===== CELL 16 (code) =====
print(result["content"])

# ===== CELL 17 (markdown) =====
# # Any-to-any inference

# ===== CELL 18 (markdown) =====
# We can also run the model with the `any-to-any` pipeline.

# ===== CELL 19 (code) =====
from transformers import pipeline

# Use the same checkpoint id as the model loaded at the top of the notebook.
pipe = pipeline("any-to-any", model=model_id)

# Video entries are referenced through "url" (as in the video-inference cell
# above); the original used the "image" key here, which is the key for
# image entries.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
            },
            {"type": "text", "text": "What is happening in this video?"},
        ],
    }
]

# ===== CELL 20 (code) =====
# The original notebook left `load_audio_from_video=True` commented out on
# this call; it stays disabled here.
pipe(messages)

# ===== CELL 21 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
            },
            {"type": "text", "text": "What is happening in this video?"},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=128)
# Per-sample trimming: remove each prompt from the front of its own output row.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

# ===== CELL 22 (markdown) =====
# # Object detection and pointing

# =====
# CELL 23 (code) =====
import re
import torch
from transformers.image_utils import load_image
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import json

# ===== CELL 24 (code) =====
image_url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png"
image = load_image(image_url)


# ===== CELL 25 (code) =====
def resize_to_48_multiple(image):
    """Crop *image* so that both sides are multiples of 48 pixels.

    Despite the name, nothing is rescaled: the surplus pixels are cut off the
    right and bottom edges via ``image.crop``.

    Args:
        image: An object exposing ``.size`` as ``(width, height)`` and
            ``.crop(box)`` (a PIL image in this notebook).

    Returns:
        The cropped image.

    Raises:
        ValueError: If either side is shorter than 48 pixels, which would
            otherwise yield an empty (zero-sized) crop.
    """
    w, h = image.size
    new_w = (w // 48) * 48
    new_h = (h // 48) * 48
    if new_w == 0 or new_h == 0:
        raise ValueError(
            f"Image of size {w}x{h} is smaller than 48 pixels on at least one side"
        )
    return image.crop((0, 0, new_w, new_h))


# ===== CELL 26 (code) =====
def inputs_for_object_detection(image, what_object):
    """Build model inputs asking for the bounding box of *what_object*.

    Args:
        image: The image to place in the user message.
        what_object: Name of the object, interpolated into the prompt text.

    Returns:
        The processor output for a single-turn chat (thinking disabled),
        moved to ``model.device``.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"What's the bounding box for the {what_object} in the image?"},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
    )
    return inputs.to(model.device)


# ===== CELL 27 (code) =====
def extract_json(text: str):
    """Return the first JSON value found in *text*.

    Surrounding whitespace and a leading/trailing Markdown code fence
    (three backticks, optionally tagged ``json``) are removed first. If the
    remainder is not valid JSON as a whole, the text is scanned for the first
    ``{`` or ``[`` from which a complete JSON object or array can be decoded;
    anything after that value is ignored.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If no JSON object or array can be decoded.
    """
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)

    # Try direct parse first.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Fallback: decode the first object/array embedded in the text.
    # raw_decode stops at the end of the first complete value, so trailing
    # prose, stray brackets or additional JSON values do not break parsing
    # (a greedy first-"{"-to-last-"}" match would).
    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[{\[]", text):
        try:
            value, _ = decoder.raw_decode(text[opener.start():])
        except json.JSONDecodeError:
            continue
        return value

    raise ValueError("No valid JSON found")


# ===== CELL 28 (code) =====
def _first_detection(parsed):
    """Return the first detection from parsed JSON (a dict or a list of dicts)."""
    if isinstance(parsed, dict):
        return parsed
    if isinstance(parsed, list) and parsed:
        return parsed[0]
    raise ValueError(f"Model output contains no detections: {parsed!r}")


def _detect_in_image(image, what_object):
    """Run detection on an already loaded and cropped image; return the first hit."""
    inputs = inputs_for_object_detection(image, what_object)
    input_len = inputs["input_ids"].shape[-1]
    generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
    # Special tokens are not stripped here; extract_json locates the JSON
    # payload inside the decoded text.
    generated = processor.decode(generated_outputs[0, input_len:])
    return _first_detection(extract_json(generated))


def detect_object(image_url, what_object):
    """Detect *what_object* in the image at *image_url*.

    The image is loaded, cropped with ``resize_to_48_multiple`` and sent to
    the model with greedy decoding.

    Args:
        image_url: Location of the image, passed to ``load_image``.
        what_object: Name of the object to look for.

    Returns:
        The first detection parsed from the model's JSON answer. A single JSON
        object is returned as-is; for a JSON array its first element is
        returned.

    Raises:
        ValueError: If the answer holds no JSON or an empty detection list.
    """
    image = resize_to_48_multiple(load_image(image_url))
    return _detect_in_image(image, what_object)


# ===== CELL 29 (code) =====
def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000, 1000)):
    """Draw one labelled box on *image*, save it as ``boxes_{i}.png`` and display it.

    Args:
        i: Suffix used in the output file name.
        image: Image to draw on; must expose ``.size`` as ``(width, height)``.
        box: Four numbers unpacked as ``(ymin, xmin, ymax, xmax)``.
        label: Text placed above the box, or ``None`` for no label.
        resize_shape: ``(height, width)`` of the coordinate grid the box is
            expressed in; coordinates are rescaled from it to the image size.
            ``None`` means the box is already in image pixels.
            NOTE(review): the default presumably matches a 0-1000 grid used
            by the model's box output — confirm.
    """
    dpi = 72
    width, height = image.size
    fig, ax = plt.subplots(1, figsize=[width / dpi, height / dpi], tight_layout={'pad': 0})
    ax.imshow(image)

    ymin, xmin, ymax, xmax = box
    re_h, re_w = resize_shape if resize_shape is not None else (height, width)
    # Rescale from the box's coordinate grid to image pixels.
    xmin = (xmin / re_w) * width
    ymin = (ymin / re_h) * height
    xmax = (xmax / re_w) * width
    ymax = (ymax / re_h) * height

    rect = patches.Rectangle(
        (xmin, ymin), xmax - xmin, ymax - ymin,
        linewidth=10, edgecolor="green", facecolor="none",
    )
    ax.add_patch(rect)
    if label is not None:
        ax.text(xmin, ymin - 25, label, fontsize=24, bbox=dict(facecolor="yellow", alpha=0.5))

    ax.axis("off")
    fig.savefig(f"boxes_{i}.png")
    # NOTE(review): presumably closed before display so the notebook backend
    # does not render the figure a second time — confirm.
    plt.close(fig)
    # `display` is not imported in this file; presumably the notebook builtin.
    display(fig)


# ===== CELL 30 (code) =====
def display_detected_object(image_url, what_object):
    """Detect *what_object* in the image at *image_url* and draw its box.

    The image is loaded once and the same cropped copy is used both for
    detection and for drawing.

    Args:
        image_url: Location of the image, passed to ``load_image``.
        what_object: Name of the object to look for; also the fallback label
            when the detection has no ``"label"`` entry.
    """
    image = resize_to_48_multiple(load_image(image_url))
    detection = _detect_in_image(image, what_object)
    box = detection["box_2d"]
    label = detection.get("label", f"{what_object}")
    draw_pascal_voc_boxes("1000", image, box, label)


# ===== CELL 31 (code) =====
display_detected_object(image_url, "bike")

# ===== CELL 32 (markdown) =====
# ## Captioning

# ===== CELL 33 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
            {"type": "text", "text": "Write single detailed caption for this image."},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=512)

input_len = inputs.input_ids.shape[-1]
generated_text_ids = output[0][input_len:]
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
result = processor.parse_response(generated_text)
print(result["content"])

# ===== CELL 34 (markdown) =====
# ## Audio Understanding

# ===== CELL 35 (code) =====
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
            {"type": "text", "text": "Can you describe this audio in detail?"},
        ],
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)
# Decode only the newly generated tokens, as the other cells do, so the
# prompt is not echoed in front of the answer.
input_len = inputs.input_ids.shape[-1]
print(processor.decode(output[0][input_len:], skip_special_tokens=True))