# === HEADER (license + imports) === # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_gemma4.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2026 the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from collections.abc import Callable from dataclasses import dataclass from functools import cached_property from typing import Optional import torch from torch import nn from torch.nn import functional as F from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernelized_func from ...masking_utils import ( create_bidirectional_mask, create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_accelerate_available, torch_compilable_check, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ..auto.modeling_auto import AutoModel from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig if is_accelerate_available(): from accelerate.hooks import add_hook_to_module @dataclass @auto_docstring( custom_intro=""" Base class for Gemma4 outputs, with hidden states and attentions. """ ) class Gemma4ModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): # === CLASS/FUNCTION OUTLINE (signatures + short body) === @dataclass @auto_docstring( custom_intro=""" Base class for Gemma4 outputs, with hidden states and attentions. """ ) class Gemma4ModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): ... @dataclass @auto_docstring( custom_intro=""" Base class for Gemma4 causal language model (or autoregressive) outputs. """ ) class Gemma4CausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). ... @dataclass @auto_docstring class Gemma4AudioModelOutput(BaseModelOutputWithPooling): r""" attention_mask (`torch.BoolTensor`, *optional*): A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding. """ attention_mask: torch.BoolTensor | None = None class Gemma4ClippableLinear(nn.Module): def __init__( self, ... class Gemma4RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): super().__init__() self.eps = eps self.with_scale = with_scale if self.with_scale: self.weight = nn.Parameter(torch.ones(dim), requires_grad=True) def _norm(self, hidden_states: torch.Tensor): mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX return hidden_states * torch.pow(mean_squared, -0.5) ... class Gemma4AudioRelPositionalEncoding(nn.Module): """Sinusoidal relative positional encoding for the audio encoder. Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with concatenated [sin..., cos...] layout matching the original Gemma4 convention. """ inv_timescales: torch.Tensor def __init__(self, config: Gemma4AudioConfig): super().__init__() self.hidden_size = config.hidden_size self.context_size = ( config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right ... class Gemma4AudioAttention(nn.Module): """Chunked local attention with relative position bias""" def __init__(self, config: Gemma4AudioConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.attention_logits_soft_cap = config.attention_logit_cap self.head_dim = config.hidden_size // config.num_attention_heads self.num_heads = config.num_attention_heads self.q_scale = (self.head_dim**-0.5) / math.log(2) self.k_scale = math.log(1 + math.e) / math.log(2) ... class Gemma4AudioSubSampleConvProjectionLayer(nn.Module): def __init__(self, in_channels, out_channels, norm_eps): super().__init__() self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(2, 2), padding=1, bias=False, ) self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False) self.act = nn.ReLU() ... class Gemma4AudioSubSampleConvProjection(nn.Module): def __init__(self, config: Gemma4AudioConfig): super().__init__() self.layer0 = Gemma4AudioSubSampleConvProjectionLayer( in_channels=1, out_channels=config.subsampling_conv_channels[0], norm_eps=config.rms_norm_eps, ) self.layer1 = Gemma4AudioSubSampleConvProjectionLayer( in_channels=config.subsampling_conv_channels[0], out_channels=config.subsampling_conv_channels[1], norm_eps=config.rms_norm_eps, ) proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1] ... class Gemma4AudioFeedForward(nn.Module): def __init__(self, config: Gemma4AudioConfig): super().__init__() self.config = config self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4) self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size) self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size) self.post_layer_norm = Gemma4RMSNorm(config.hidden_size) self.act_fn = ACT2FN[config.hidden_act] self.gradient_clipping = config.gradient_clipping self.post_layer_scale = config.residual_weight ... class Gemma4AudioCausalConv1d(nn.Conv1d): # def __init__( # self, # in_channels: int, # out_channels: int, # kernel_size: int, # # cache_key: str, # stride: int = 1, # dilation: int = 1, # bias: bool = True, # ): # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias) # self.cache_key = cache_key ... class Gemma4AudioLightConv1d(nn.Module): def __init__(self, config: Gemma4AudioConfig): super().__init__() self.config = config self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2) self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size) self.depthwise_conv1d = Gemma4AudioCausalConv1d( in_channels=config.hidden_size, out_channels=config.hidden_size, kernel_size=config.conv_kernel_size, groups=config.hidden_size, bias=False, ) ... class Gemma4AudioLayer(nn.Module): def __init__(self, config: Gemma4AudioConfig, layer_idx: int): super().__init__() self.config = config self.feed_forward1 = Gemma4AudioFeedForward(config) self.feed_forward2 = Gemma4AudioFeedForward(config) self.self_attn = Gemma4AudioAttention(config, layer_idx) self.lconv1d = Gemma4AudioLightConv1d(config) self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size) self.norm_post_attn = Gemma4RMSNorm(config.hidden_size) self.norm_out = Gemma4RMSNorm(config.hidden_size) ... class Gemma4VisionPatchEmbedder(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.patch_size = config.patch_size self.position_embedding_size = config.position_embedding_size self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False) self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size)) def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor: """Prepare patch positions map for matmul with positon embedding table.""" # Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul. ... class Gemma4VisionPooler(nn.Module): """Scaling and optional spatial pooling for vision encodings""" def __init__(self, config: Gemma4VisionConfig): super().__init__() self.hidden_size = config.hidden_size self.root_hidden_size = self.hidden_size**0.5 def _avg_pool_by_positions( self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int ) -> tuple[torch.Tensor, torch.Tensor]: """ 2D spatial pooling according to patch positions. Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between ... class Gemma4VisionMLP(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size) self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size) self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size) self.act_fn = ACT2FN[config.hidden_activation] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj ... class Gemma4VisionRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: Gemma4VisionConfig, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_type = self.config.rope_parameters["rope_type"] rope_init_fn: Callable = self.compute_default_rope_parameters if self.rope_type != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] ... def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1): """Applies Rotary Position Embedding to the query and key tensors. Args: x (`torch.Tensor`): The tensor to embed. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. ... def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, ... def apply_multidimensional_rope( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor, unsqueeze_dim: int = 2, ) -> torch.Tensor: """Applies multidimensional RoPE to inputs. Args: x (`torch.Tensor`): The tensor to embed. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): ... @use_kernelized_func(apply_rotary_pos_emb) class Gemma4VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma4VisionConfig, layer_idx: int): super().__init__() self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = 1.0 self.attention_dropout = self.config.attention_dropout self.is_causal = False ... class Gemma4VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma4VisionConfig, layer_idx: int): super().__init__() self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx) self.mlp = Gemma4VisionMLP(config) self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) def forward( ... class Gemma4VisionEncoder(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.config = config self.num_layers = config.num_hidden_layers self.rotary_emb = Gemma4VisionRotaryEmbedding(config) self.layers = nn.ModuleList( [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)] ) def forward( self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, ... class Gemma4TextMLP(nn.Module): def __init__(self, config: Gemma4TextConfig, layer_idx: int): super().__init__() first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1) self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_activation] ... class Gemma4TextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.layer_types = set(config.layer_types) self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {} self.rope_type: dict[str, str] = {} for layer_type in self.layer_types: ... @use_kernelized_func(apply_rotary_pos_emb) class Gemma4TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma4TextConfig, layer_idx: int): super().__init__() self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config self.layer_idx = layer_idx self.is_sliding = self.layer_type == "sliding_attention" self.sliding_window = config.sliding_window if self.is_sliding else None self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding ... @use_experts_implementation class Gemma4TextExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config: Gemma4TextConfig): super().__init__() self.num_experts = config.num_experts self.hidden_dim = config.hidden_size self.intermediate_dim = config.moe_intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_activation] def forward( ... class Gemma4TextRouter(nn.Module): def __init__(self, config: Gemma4TextConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.scalar_root_size = self.hidden_size**-0.5 self.eps = config.rms_norm_eps self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False) self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.scale = nn.Parameter(torch.ones(self.hidden_size)) self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ... class Gemma4TextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int): super().__init__() self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx) self.mlp = Gemma4TextMLP(config, layer_idx) self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.register_buffer("layer_scalar", torch.ones(1)) ... class Gemma4TextScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0): super().__init__(num_embeddings, embedding_dim, padding_idx) self.scalar_embed_scale = embed_scale self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) ... @auto_docstring class Gemma4PreTrainedModel(PreTrainedModel): config: Gemma4Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] _skip_keys_device_placement = ["past_key_values", "shared_kv_states"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = None # override ... @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") class Gemma4TextModel(Gemma4PreTrainedModel): config: Gemma4TextConfig input_modalities = ("text",) _can_record_outputs = { "router_logits": OutputRecorder(Gemma4TextRouter, index=0), "hidden_states": Gemma4TextDecoderLayer, "attentions": Gemma4TextAttention, } def __init__(self, config: Gemma4TextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size ... @auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.") class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_gather_output"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config: Gemma4TextConfig base_model_prefix = "model" def __init__(self, config: Gemma4TextConfig): super().__init__(config) self.model = Gemma4TextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Grab the ones from the child ... def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable: """ This creates uni/bidirectional attention mask with sliding window. """ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: left_window_size, right_window_size = sliding_window dist = q_idx - kv_idx left_mask = (dist >= 0) & (dist < left_window_size) right_mask = (dist < 0) & (-dist < right_window_size) return left_mask | right_mask return inner_mask ... class Gemma4AudioModel(Gemma4PreTrainedModel): """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.""" config: Gemma4AudioConfig main_input_name = "input_features" base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained() _can_record_outputs = { "hidden_states": Gemma4AudioLayer, "attentions": Gemma4AudioAttention, } def __init__(self, config: Gemma4AudioConfig): super().__init__(config) self.config = config ... class Gemma4VisionModel(Gemma4PreTrainedModel): """The Gemma 4 Vision Encoder.""" config = Gemma4VisionConfig _can_record_outputs = { "hidden_states": Gemma4VisionEncoderLayer, "attentions": Gemma4VisionAttention, } def __init__(self, config: Gemma4VisionConfig): super().__init__(config) self.patch_embedder = Gemma4VisionPatchEmbedder(config) self.encoder = Gemma4VisionEncoder(config) self.pooler = Gemma4VisionPooler(config) ... class Gemma4MultimodalEmbedder(nn.Module): """Embeds token ids or soft tokens for multimodal content into language model space.""" def __init__( self, multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig, text_config: Gemma4TextConfig, ): super().__init__() self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size) self.eps = multimodal_config.rms_norm_eps self.text_hidden_size = text_config.hidden_size self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False) ... def token_type_ids_mask_function( token_type_ids: torch.Tensor | None, image_group_ids: torch.Tensor | None, ) -> Callable | None: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. """ # Do not return an additional mask in this case if token_type_ids is None: return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: seq_length = image_group_ids.shape[-1] ... def create_causal_mask_mapping( config: PreTrainedConfig, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor | None, past_key_values: Cache | None, position_ids: torch.Tensor | None, mm_token_type_ids: torch.Tensor | None = None, pixel_values: torch.FloatTensor | None = None, is_training: bool = False, is_first_iteration: bool | None = None, **kwargs, ) -> dict: """ Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping ... @auto_docstring( custom_intro=""" The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a language modeling head. """ ) class Gemma4Model(Gemma4PreTrainedModel): # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch accepts_loss_kwargs = False def __init__(self, config: Gemma4Config): super().__init__(config) self.vocab_size = config.text_config.vocab_size ... @auto_docstring( custom_intro=""" The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling head. """ ) class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} accepts_loss_kwargs = False base_model_prefix = "model" def __init__(self, config: Gemma4Config): super().__init__(config) self.model = Gemma4Model(config) ...