from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from ..base import InputEmbeddingsFeatures
from .audio import AudioEncoder
from .config import ModelConfig
from .language import LanguageModel, RMSNormNoScale
from .vision import VisionModel


def masked_scatter(input_tensor, mask, source):
    """Fill the masked positions of ``input_tensor`` with values from ``source``.

    All three arguments are flattened. The k-th non-zero entry of the
    flattened mask (counting from zero) receives
    ``source.flatten()[k % source.size]``; every other position keeps the
    corresponding value of ``input_tensor``. The result is reshaped back to
    ``input_tensor.shape``.

    NOTE(review): assumes ``mask`` has as many elements as ``input_tensor``
    (or broadcasts against its flattened form) -- confirm against callers.
    """
    flat_mask = mask.flatten().astype(mx.int32)

    # A running count of selected positions, shifted to be zero-based, gives
    # each selected position the index of the source element it should take.
    source_positions = mx.cumsum(flat_mask) - 1

    # Indices are wrapped modulo ``source.size``; whatever is gathered at
    # unselected positions is discarded by the ``mx.where`` below.
    flat_source = source.flatten()
    gathered = flat_source[source_positions % source.size]

    merged = mx.where(flat_mask, gathered, input_tensor.flatten())
    return merged.reshape(input_tensor.shape)


class MultimodalEmbedder(nn.Module):
    """Map vision/audio soft tokens into the language model's hidden space.

    Inputs are first passed through an ``RMSNormNoScale`` over
    ``embedding_dim`` and then through a bias-free linear projection to
    ``text_hidden_size``.
    """

    def __init__(self, embedding_dim: int, text_hidden_size: int, eps: float = 1e-6):
        """Create the projection and its pre-projection norm.

        Args:
            embedding_dim: Feature size of the incoming soft tokens.
            text_hidden_size: Output feature size of the linear projection.
            eps: Epsilon forwarded to ``RMSNormNoScale``.
        """
        super().__init__()
        self.embedding_projection = nn.Linear(
            embedding_dim,
            text_hidden_size,
            bias=False,
        )
        self.embedding_pre_projection_norm = RMSNormNoScale(embedding_dim, eps=eps)

    def __call__(self, inputs_embeds: mx.array) -> mx.array:
        """Normalize ``inputs_embeds`` and return its linear projection."""
        return self.embedding_projection(
            self.embedding_pre_projection_norm(inputs_embeds)
        )


class Model(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.model_type = config.model_type
        self.config = config

        # Text
        self.language_model = LanguageModel(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Vision
        self.vision_tower = VisionModel(config.vision_config)
        self.embed_vision = MultimodalEmbedder(
            embedding_dim=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            eps=config.vision_config.rms_norm_eps,
        )

        # Audio
        if config.audio_config is not None:
            self.audio_tower = AudioEncoder(config.audio_config)
            audio_output_dim = (
                config.audio_config.output_proj_dims
                or config.audio_config.hidden_size
            )