# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/aria/modular_aria.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_aria.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from functools import partial from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, ) from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import AutoModel, AutoModelForCausalLM from .configuration_aria import AriaConfig, AriaTextConfig if is_torch_available(): import torch from torch import nn if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "AriaTextConfig" class AriaTextRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ AriaTextRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class AriaProjectorMLP(nn.Module): """ Feed-Forward Network module for the Aria Projector. Args: in_features (`int`): Input embedding dimension. hidden_features (`int`): Hidden dimension of the feed-forward network. output_dim (`int`): Output dimension. """ def __init__(self, in_features, hidden_features, output_dim): super().__init__() self.linear_in = nn.Linear(in_features, hidden_features, bias=False) self.linear_out = nn.Linear(hidden_features, output_dim, bias=False) self.act = ACT2FN["gelu_new"] def forward(self, hidden_states): hidden_states = self.act(self.linear_in(hidden_states)) hidden_states = self.linear_out(hidden_states) return hidden_states class AriaCrossAttention(nn.Module): """ Aria Cross-Attention module. Args: config (`AriaConfig`): The configuration to use. """ def __init__(self, config: AriaConfig, dropout_rate: float = 0): super().__init__() hidden_size = config.vision_config.hidden_size num_heads = config.vision_config.num_attention_heads self.num_heads = num_heads self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48 self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True) self.linear = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(dropout_rate) self.layer_norm = nn.LayerNorm(hidden_size) self.layer_norm_kv = nn.LayerNorm(hidden_size) def forward(self, key_value_states, hidden_states, attn_mask=None): """ Forward pass of the AriaCrossAttention module. Args: key_value_states (`torch.Tensor`): Input tensor for key and value. hidden_states (`torch.Tensor`): Input tensor for query. attn_mask (`torch.Tensor`, *optional*, defaults to None): Attention mask. Returns: torch.Tensor: Output tensor after cross-attention. """ query = self.q_proj(self.layer_norm(hidden_states)) key_value_states = self.layer_norm_kv(key_value_states) key = self.k_proj(key_value_states) value = self.v_proj(key_value_states) attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask) attn_output = self.dropout(self.linear(attn_output)) return attn_output class AriaProjector(nn.Module): """ Aria Projector module. This module projects vision features into the language model's embedding space, enabling interaction between vision and language components. Args: config (`AriaConfig`): Configuration object for the model. """ def __init__( self, config: AriaConfig, ): super().__init__() self.patch_to_query_dict = config.projector_patch_to_query_dict self.in_features = config.vision_config.hidden_size self.num_heads = config.vision_config.num_attention_heads self.kv_dim = config.vision_config.hidden_size self.hidden_features = config.text_config.hidden_size self.output_dim = config.text_config.hidden_size self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features)) self.cross_attn = AriaCrossAttention(config) self.layer_norm = nn.LayerNorm(self.in_features) self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim) def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): """ Forward pass of the Projector module. Args: key_value_states (`torch.Tensor`): Input tensor of shape (batch_size, num_patches, kv_dim). attn_mask (`torch.Tensor`, *optional*, default is None): Attention mask. Returns: `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim). """ batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1] if num_patches not in self.patch_to_query_dict.keys(): raise KeyError( f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}." ) query_num = self.patch_to_query_dict[num_patches] queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1) if attn_mask is not None: attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask) out = self.feed_forward(self.layer_norm(attention_out)) return out class AriaSharedExpertsMLP(nn.Module): """ Shared Expert MLP for shared experts. Unlike routed experts, shared experts process all tokens without routing. This class reconfigures the intermediate size in comparison to the LlamaMLP. Args: config (`AriaTextConfig`): Configuration object for the Aria language model. """ def __init__(self, config: AriaTextConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert): """ Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts. Args: token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features). expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features). tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. Returns: torch.Tensor: Output tensor of shape (num_tokens, out_features). """ num_tokens = token_states.shape[0] out_features = expert_weights.shape[-1] output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device) cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) # Insert zero at the beginning for offset index's convenience zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device) cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) for expert_num in range(expert_weights.shape[0]): start = cumsum_num_tokens[expert_num] end = cumsum_num_tokens[expert_num + 1] tokens = token_states[start:end] out = torch.matmul(tokens, expert_weights[expert_num]) output[start:end] = out return output class AriaGroupedExpertsGemm(nn.Module): """ Grouped GEMM (General Matrix Multiplication) module for efficient expert computation. This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm) for optimized performance. If the grouped_gemm library is not installed, it gracefully falls back to a sequential GEMM implementation, which may be slower but ensures functionality. Args: in_features (`int`): Number of input features. out_features (`int`): Number of output features. groups (`int`): Number of expert groups. """ def __init__(self, in_features, out_features, groups): super().__init__() self.in_features = in_features self.out_features = out_features self.groups = groups self.weight = nn.Parameter(torch.empty(groups, in_features, out_features)) def forward(self, input, tokens_per_expert): """ Perform grouped matrix multiplication. Args: input (`torch.Tensor`): Input tensor of shape (num_tokens, in_features). tokens_per_expert (`torch.Tensor`): Number of tokens assigned to each expert. Returns: torch.Tensor: Output tensor of shape (num_tokens, out_features). """ return sequential_experts_gemm( input, self.weight, tokens_per_expert.cpu(), ) class AriaGroupedExpertsMLP(nn.Module): """ Grouped MLP module for Mixture of Experts. Args: config (`AriaTextConfig`): Configuration object for the model. """ def __init__(self, config: AriaTextConfig) -> None: super().__init__() self.config = config self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts) self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts) def forward(self, permuted_tokens, tokens_per_expert): """ Forward pass of the Grouped MLP. Args: permuted_tokens (torch.Tensor): Permuted input tokens. tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. Returns: torch.Tensor: Output tensor after passing through the MLP. """ fc1_output = self.fc1(permuted_tokens, tokens_per_expert) projection, gate = torch.chunk(fc1_output, 2, dim=-1) fc1_output = nn.functional.silu(projection) * gate fc2_output = self.fc2(fc1_output, tokens_per_expert) return fc2_output # Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587 class AriaTextMoELayer(nn.Module): """ Aria Text Mixture of Experts (MoE) Layer. This layer applies a gating mechanism to route input tokens to different experts. Args: config (`AriaTextConfig`): Configuration object for the text component of the model. """ def __init__(self, config: AriaTextConfig): super().__init__() self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) self.experts = AriaGroupedExpertsMLP(config) self.shared_experts = AriaSharedExpertsMLP(config) self.config = config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Forward pass of the MoE Layer. Args: hidden_states (`torch.Tensor`): Input tensor of shape (batch_size, sequence_length, hidden_size). Returns: torch.Tensor: Output tensor after passing through the MoE layer. Process: 1. Route tokens to experts using the router. 2. Permute tokens based on routing decisions. 3. Process tokens through experts. 4. Unpermute and combine expert outputs. 5. Add shared expert output to the final result. """ original_shape = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_states.size(-1)) # Top K Routing logits = self.router(hidden_states) top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1) scores = nn.functional.softmax(top_logits, dim=-1) original_dtype = top_indices.dtype tokens_per_expert = torch.histc( top_indices.flatten().to(torch.float32), bins=self.config.moe_num_experts, min=0, max=self.config.moe_num_experts - 1, ).to(original_dtype) indices = top_indices # Token permutation flatten_indices = indices.view(-1) sorted_indices = torch.argsort(flatten_indices) permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk) # Process through experts expert_output = self.experts(permuted_tokens, tokens_per_expert) # Token unpermutation unpermuted_tokens = torch.zeros( (scores.shape[0] * self.config.moe_topk, expert_output.size(1)), dtype=expert_output.dtype, device=expert_output.device, ) unpermuted_tokens.index_copy_(0, sorted_indices, expert_output) unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1)) output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape) # Add shared expert output shared_expert_output = self.shared_experts(hidden_states.view(original_shape)) return output + shared_expert_output def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class AriaTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: AriaTextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class AriaTextDecoderLayer(nn.Module): """ Aria Text Decoder Layer. This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network. Args: config (`AriaTextConfig`): Configuration object for the text component of the model. layer_idx (`int`): Index of the layer. """ def __init__(self, config: AriaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx) self.mlp = AriaTextMoELayer(config) self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class AriaTextPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = AriaConfig base_model_prefix = "model" _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, AriaGroupedExpertsGemm): module.weight.data.normal_(mean=0.0, std=std) elif isinstance(module, nn.Conv2d): module.weight.data.normal_(mean=0.0, std=std) if hasattr(module, "bias") and module.bias is not None: module.bias.data.zero_() ARIA_TEXT_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`AriaTextConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare Aria Model outputting raw hidden-states without any specific head on top.", ARIA_TEXT_START_DOCSTRING, ) class AriaPreTrainedModel(PreTrainedModel): config_class = AriaTextConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["AriaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = False def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, AriaProjector): nn.init.trunc_normal_(module.query, std=std) class AriaTextRotaryEmbedding(nn.Module): def __init__(self, config: AriaTextConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ARIA_TEXT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare AriaText Model outputting raw hidden-states without any specific head on top.", ARIA_TEXT_START_DOCSTRING, ) class AriaTextModel(AriaTextPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AriaTextDecoderLayer`] Args: config: AriaTextConfig """ def __init__(self, config: AriaTextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = AriaTextRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache if not isinstance(past_key_values, (type(None), Cache)): raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) if isinstance(attention_mask, BlockMask): return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): """ Aria model for causal language modeling tasks. This class extends `LlamaForCausalLM` to incorporate the Mixture of Experts (MoE) approach, allowing for more efficient and scalable language modeling. Args: config (`AriaTextConfig`): Configuration object for the model. """ _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): super().__init__(config) self.model = AriaTextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from transformers import AutoTokenizer, AriaTextForCausalLM >>> model = AriaTextForCausalLM.from_pretrained("meta-aria_text/AriaText-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-aria_text/AriaText-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @dataclass class AriaCausalLMOutputWithPast(ModelOutput): """ Base class for Aria causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None ARIA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor`, *optional*): Input token IDs. pixel_values (`torch.FloatTensor`, *optional*): Pixel values of the images. pixel_mask (`torch.LongTensor`, *optional*): Mask for the pixel values. attention_mask (`torch.Tensor`, *optional*): Attention mask. position_ids (`torch.LongTensor`, *optional*): Position IDs. past_key_values (`List[torch.FloatTensor]`, *optional*): Past key values for efficient processing. inputs_embeds (`torch.FloatTensor`, *optional*): Input embeddings. labels (`torch.LongTensor`, *optional*): Labels for computing the language modeling loss. use_cache (`bool`, *optional*): Whether to use the model's cache mechanism. output_attentions (`bool`, *optional*): Whether to output attention weights. output_hidden_states (`bool`, *optional*): Whether to output hidden states. return_dict (`bool`, *optional*): Whether to return a `ModelOutput` object. logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`. Otherwise, slice according to the 1D tensor in the sequence length dimension cache_position (`torch.LongTensor`, *optional*): Cache positions. **loss_kwargs: Additional keyword arguments for loss calculation. """ ARIA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config (`AriaConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( """Aria model for conditional generation tasks. This model combines a vision tower, a multi-modal projector, and a language model to perform tasks that involve both image and text inputs.""", ARIA_START_DOCSTRING, ) class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): config_class = AriaConfig _supports_flash_attn_2 = False _supports_flex_attn = False _supports_sdpa = False _tied_weights_keys = ["language_model.lm_head.weight"] def __init__(self, config: AriaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = AriaProjector(config) self.vocab_size = config.text_config.vocab_size self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" self.post_init() def _create_patch_attention_mask(self, pixel_mask): if pixel_mask is None: return None patches_subgrid = pixel_mask.unfold( dimension=1, size=self.vision_tower.config.patch_size, step=self.vision_tower.config.patch_size, ) patches_subgrid = patches_subgrid.unfold( dimension=2, size=self.vision_tower.config.patch_size, step=self.vision_tower.config.patch_size, ) return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def get_output_embeddings(self): return self.language_model.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def set_decoder(self, decoder): self.language_model.set_decoder(decoder) def get_decoder(self): return self.language_model.get_decoder() def get_image_features( self, pixel_values: torch.FloatTensor, pixel_mask: Optional[torch.FloatTensor] = None, vision_feature_layer: int = -1, ): patch_attention_mask = self._create_patch_attention_mask(pixel_mask) image_outputs = self.vision_tower( pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True ) image_attn_mask = None if patch_attention_mask is not None: flattened_mask = patch_attention_mask.flatten(1) image_attn_mask = torch.logical_not(flattened_mask) selected_image_feature = image_outputs.hidden_states[vision_feature_layer] image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, ) -> AriaCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from transformers import AutoProcessor, AutoModel >>> from transformers.image_utils import load_image >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg") >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg") >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg") >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria") >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", torch_dtype=torch.bfloat16, device_map="auto") >>> # Create inputs >>> messages = [ ... { ... "role": "user", ... "content": [ ... {"type": "image"}, ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."}, ... {"type": "image"}, ... {"type": "text", "text": "What can we see in this image?"}, ... ] ... }, ... { ... "role": "user", ... "content": [ ... {"type": "image"}, ... {"type": "text", "text": "In which city is that bridge located?"}, ... ] ... } ... ] >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages] >>> images = [[image1, image2], [image3]] >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device) >>> # Generate >>> generated_ids = model.generate(**inputs, max_new_tokens=256) >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) >>> print(generated_texts[0]) Assistant: There are buildings, trees, lights, and water visible in this image. >>> print(generated_texts[1]) Assistant: The bridge is in San Francisco. ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) # 2. Merge text and images if pixel_values is not None and inputs_embeds.shape[1] != 1: if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) ) n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] else: image_embeds = input_ids == self.config.image_token_index special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) image_features = self.get_image_features( pixel_values=pixel_values, pixel_mask=pixel_mask, vision_feature_layer=self.config.vision_feature_layer, ) n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] n_image_features = n_images * n_features_per_image if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs: CausalLMOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, logits_to_keep=logits_to_keep, cache_position=cache_position, ) logits = outputs.logits loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs ) return AriaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, pixel_mask=None, attention_mask=None, cache_position=None, logits_to_keep=None, **kwargs, ): model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values model_inputs["pixel_mask"] = pixel_mask return model_inputs __all__ = [ "AriaForConditionalGeneration", "AriaPreTrainedModel", "AriaTextPreTrainedModel", "AriaTextModel", "AriaTextForCausalLM", ]
Memory