# coding=utf-8 # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig from ...activations import ACT2FN from ...cache_utils import Cache, HybridChunkedCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_flex_attn_available, logging, replace_return_docstrings, ) from .configuration_llama4 import Llama4Config, Llama4TextConfig if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B" _CONFIG_FOR_DOC = "Llama4Config" class Llama4TextExperts(nn.Module): def __init__(self, config: Llama4Config): super().__init__() self.num_experts = config.num_local_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ This should really not be run on a single machine, as we are reaching compute bound: - the inputs are expected to be "sorted" per expert already. - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape Args: hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) selected_experts (torch.Tensor): (batch_size * token_num, top_k) routing_weights (torch.Tensor): (batch_size * token_num, top_k) Returns: torch.Tensor """ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) next_states = next_states.view(-1, self.hidden_size) return next_states # Phi3MLP class Llama4TextMLP(nn.Module): def __init__(self, config, intermediate_size=None): super().__init__() if intermediate_size is None: intermediate_size = config.intermediate_size self.config = config self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False) self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False) self.activation_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x) return self.down_proj(down_proj) class Llama4TextL2Norm(torch.nn.Module): def __init__(self, eps: float = 1e-6): super().__init__() self.eps = eps def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): return self._norm(x.float()).type_as(x) def extra_repr(self): return f"eps={self.eps}" class Llama4TextRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-5): """ Llama4RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(hidden_size)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" class Llama4TextMoe(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts self.experts = Llama4TextExperts(config) self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) self.shared_expert = Llama4TextMLP(config) def forward(self, hidden_states): batch, seq_len, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = self.router(hidden_states).transpose(0, 1) tokens_per_expert = batch * seq_len router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1) router_scores = ( torch.full_like(router_logits.transpose(0, 1), float("-inf")) .scatter_(1, router_indices, router_top_value) .transpose(0, 1) ) # We do this to make sure we have -inf for non topK tokens before going through the ! # Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this! router_indices = ( torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1) ) router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim) routed_in = torch.gather( input=hidden_states, dim=0, index=router_indices, ).to(hidden_states.device) # we gather inputs corresponding to each expert based on the router indices routed_in = routed_in * router_scores.reshape(-1, 1) routed_out = self.experts(routed_in) out = self.shared_expert(hidden_states) # now that we finished expert computation -> we scatter add because we gathered previously # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound # this scales a lot better if you do EP! out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim)) return out, router_scores class Llama4TextRotaryEmbedding(nn.Module): def __init__(self, config: Llama4TextConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" self.rope_type = "llama3" if config.rope_scaling is not None else "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention freqs_cis = freqs_cis * self.attention_scaling return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Llama4TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Llama4TextConfig, layer_idx): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_attention_heads = config.num_attention_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attn_scale = config.attn_scale self.floor_scale = config.floor_scale self.attn_temperature_tuning = config.attn_temperature_tuning self.attention_dropout = config.attention_dropout self.is_causal = True self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) if self.config.use_qk_norm and self.use_rope: self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if self.use_rope: # the 16E model skips rope for long context on certain layers query_states, key_states = apply_rotary_emb( query_states, key_states, position_embeddings.to(query_states.device) ) if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm query_states = self.qk_norm(query_states) key_states = self.qk_norm(key_states) # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers if self.attn_temperature_tuning and not self.use_rope: attn_scales = ( torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 ) attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 query_states = (query_states * attn_scales).to(query_states.dtype) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Llama4TextDecoderLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Llama4TextAttention(config, layer_idx) self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope self.is_moe_layer = layer_idx in config.moe_layers if self.is_moe_layer: # the 128E model interleaves dense / sparse self.feed_forward = Llama4TextMoe(config) else: self.feed_forward = Llama4TextMLP(config, intermediate_size=config.intermediate_size_mlp) self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, chunk_causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # use local attention mask for ROPE layers if self.use_chunked_attention and chunk_causal_mask is not None: attention_mask = chunk_causal_mask # Self Attention attention_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = residual + attention_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.feed_forward(hidden_states) if self.is_moe_layer: hidden_states, router_logits = hidden_states else: router_logits = None hidden_states = residual + hidden_states.view(residual.shape) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if output_router_logits: outputs += (router_logits,) return outputs LLAMA4_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Llama4Config`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare Llama4 Model outputting raw hidden-states without any specific head on top.", LLAMA4_START_DOCSTRING, ) class Llama4PreTrainedModel(PreTrainedModel): config_class = Llama4Config supports_gradient_checkpointing = True _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = False _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True def _init_weights(self, module): std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() LLAMA4_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Llama4 Model outputting raw hidden-states without any specific head on top.", LLAMA4_START_DOCSTRING, ) class Llama4TextModel(Llama4PreTrainedModel): _no_split_modules = ["Llama4TextDecoderLayer"] base_model_prefix = "model" config_class = Llama4TextConfig def __init__(self, config: Llama4TextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Llama4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Llama4TextRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) if use_cache and past_key_values is None: past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1]) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask, chunk_causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers freq_cis = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, chunk_causal_mask, position_ids, past_key_values, output_attentions, False, # output_router_logits is False use_cache, cache_position, freq_cis, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, chunk_causal_mask=chunk_causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=freq_cis, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) output = BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) return output if return_dict else output.to_tuple() @torch.compiler.disable(recursive=False) # the operations in this method are not compilable def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, chunked_attention_mask=None, use_cache=True, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask, attention_mask # flash does not support chunked attn TODO support flash return None, None if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]: return None, None sequence_length = input_tensor.shape[1] cache_position = cache_position.to(self.device) attention_chunk_size = self.config.attention_chunk_size first_cache_position = cache_position[0] if past_key_values is not None: full_cache_length = past_key_values.get_max_cache_shape() or sequence_length else: full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length cond1 = first_cache_position >= attention_chunk_size cond2 = (first_cache_position < attention_chunk_size) & ( first_cache_position + sequence_length > attention_chunk_size ) key_length = ( torch.where( cond1, attention_chunk_size + sequence_length - 1, torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), ) if use_cache else full_cache_length ) if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0)) chunked_attention_mask = make_flex_block_causal_mask( attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets ) attention_mask = make_flex_block_causal_mask( attention_mask, query_length=sequence_length, key_length=full_cache_length, offsets=(first_cache_position, 0), ) return attention_mask, chunked_attention_mask if isinstance(attention_mask, BlockMask): return attention_mask, chunked_attention_mask # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). dtype, device = input_tensor.dtype, input_tensor.device causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=max(full_cache_length, attention_chunk_size), dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if full_cache_length > self.config.attention_chunk_size: start_idx = max(first_cache_position - attention_chunk_size + 1, 0) end_idx = start_idx + key_length chunked_attention_mask = self.create_chunked_attention_mask( self.config.attention_chunk_size, start=start_idx, # same offset as with flex end=end_idx, device=device, ) local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well # It may be smaller than attention_chunk_size -> pad it requires_padding = local_attention_mask.shape[-1] < attention_chunk_size if requires_padding: local_attention_mask = nn.functional.pad( local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) ) # Depending on the padding, take the query tokens from the end or the cache_position if not requires_padding: chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] else: chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1) chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] if self.config._attn_implementation == "eager": min_dtype = torch.finfo(dtype).min chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"] and attention_mask.ndim == 4 and not output_attentions # Only unmask for 4d masks ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None: chunked_attention_mask = chunked_attention_mask.bool() causal_mask = causal_mask.bool() if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=first_cache_position, is_training=self.training, ): causal_mask = None return causal_mask, chunked_attention_mask def create_chunked_attention_mask( self, attention_chunk_size: int, start: int, end: int, device: torch.device ) -> torch.Tensor: """ Generate the following: 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | If the chunk size is 3. This can just be appplied over the already created attention mask """ arange_vector = torch.arange(start, end, device=device) block_pos = torch.abs( arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size ) token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) mask = (block_pos == 0) & (token_pos <= 0) return mask.to(device) @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to plcae the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): base_model_prefix = "language_model" _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} config_class = Llama4TextConfig def __init__(self, config: Llama4TextConfig): super().__init__(config) self.model = Llama4TextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from transformers import AutoTokenizer, Llama4ForCausalLM >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @dataclass class Llama4CausalLMOutputWithPast(ModelOutput): """ Base class for Llava causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None class Llama4VisionMLP2(torch.nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.fc1 = nn.Linear(self.intermediate_size, config.projector_input_dim, bias=False) self.fc2 = nn.Linear(config.projector_output_dim, config.projector_output_dim, bias=False) self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] self.dropout = config.projector_dropout def forward(self, hidden_states): hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) return self.activation_fn(self.fc2(hidden_states)) class Llama4MultiModalProjector(nn.Module): def __init__(self, config): super().__init__() self.linear_1 = nn.Linear( config.vision_config.vision_output_dim, config.text_config.hidden_size, bias=False, ) def forward(self, image_features): hidden_states = self.linear_1(image_features) return hidden_states def pixel_shuffle(input_tensor, shuffle_ratio): # input_tensor: [batch_size, num_patches, channels] batch_size, num_patches, channels = input_tensor.shape patch_size = int(math.sqrt(num_patches)) input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) batch_size, height, width, channels = input_tensor.size() reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() reshaped_tensor = reshaped_tensor.view( batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2)) ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) return output_tensor class Llama4VisionPixelShuffleMLP(nn.Module): def __init__(self, config): super().__init__() self.pixel_shuffle_ratio = config.pixel_shuffle_ratio self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2)) self.output_dim = config.projector_output_dim self.mlp = Llama4VisionMLP2(config) def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) return self.mlp(encoded_patches) LLAVA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LlavaConfig`] or [`LlavaVisionConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ # TODO there is a different RoPE for vision encoder, defined as below def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): ndim = query.ndim shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)] return freqs_ci.view(*shape) def vision_apply_rotary_emb( query: torch.Tensor, key: torch.Tensor, freqs_ci: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_) # freqs_ci[:,:,None,:] freqs_ci = freqs_ci.to(query_.device) query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) # but this drops to 8e-3 class Llama4VisionAttention(nn.Module): def __init__(self, config: Llama4VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = 1 self.attention_dropout = config.attention_dropout self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True) def forward( self, hidden_states: torch.Tensor, freqs_ci: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape) value_states = self.v_proj(hidden_states).view(hidden_shape) query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) attention_interface: Callable = eager_attention_forward # flex disable because breaks on TP 8, embed is 88 not power of 2 if self.config._attn_implementation not in ["eager", "flex_attention"]: if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, None, dropout=0.0 if not self.training else self.attention_dropout, scaling=None, is_causal=False, # HAS TO BE ENFORCED **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Llama4VisionMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class Llama4VisionEncoderLayer(nn.Module): def __init__(self, config: Llama4VisionConfig): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Llama4VisionAttention(config) self.mlp = Llama4VisionMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) def forward( self, hidden_state: torch.Tensor, freqs_ci: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = None, ): # Self Attention residual = hidden_state hidden_state = self.input_layernorm(hidden_state) hidden_state, attn_weights = self.self_attn( hidden_state, freqs_ci=freqs_ci, attention_mask=attention_mask, ) hidden_state = residual + hidden_state # Feed forward residual = hidden_state hidden_state = self.post_attention_layernorm(hidden_state) hidden_state = self.mlp(hidden_state) hidden_state = residual + hidden_state outputs = (hidden_state,) if output_attentions: outputs += (attn_weights,) return outputs class Llama4VisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Llama4VisionEncoderLayer`]. Args: config: Llama4VisionConfig """ def __init__(self, config: Llama4VisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self.config = config def forward( self, hidden_states: torch.Tensor, freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, freqs_ci, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_state=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, freqs_ci=freqs_ci, ) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = layer_outputs[0] if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) class Llama4UnfoldConvolution(nn.Module): def __init__(self, config): super().__init__() kernel_size = config.patch_size if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) self.linear = nn.Linear( config.num_channels * kernel_size[0] * kernel_size[1], config.hidden_size, bias=False, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) hidden_states = hidden_states.permute(0, 2, 1) hidden_states = self.linear(hidden_states) return hidden_states class Llama4VisionRotaryEmbedding(nn.Module): def __init__(self, config): super().__init__() idx = config.image_size // config.patch_size img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) img_idx[-1, -1] = -2 # ID_CLS_TOKEN frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y freq_dim = config.hidden_size // config.num_attention_heads // 2 rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)) freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 def forward(self, hidden_states): return self.freqs_ci.to(hidden_states.device) class Llama4VisionModel(Llama4PreTrainedModel): base_model_prefix = "vision_model" _no_split_modules = ["Llama4VisionEncoderLayer"] config_class = Llama4VisionConfig def __init__(self, config: Llama4VisionConfig): super().__init__(config) self.image_size = config.image_size self.patch_size = config.patch_size self.hidden_size = config.hidden_size self.num_channels = config.num_channels self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 self.patch_embedding = Llama4UnfoldConvolution(config) self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) self.positional_embedding_vlm = nn.Parameter(self.scale * torch.randn(self.num_patches, self.hidden_size)) self.rotary_embedding = Llama4VisionRotaryEmbedding(config) # layer norms self.layernorm_pre = nn.LayerNorm(self.hidden_size) self.layernorm_post = nn.LayerNorm(self.hidden_size) # encoders self.model = Llama4VisionEncoder(config) self.vision_adapter = Llama4VisionPixelShuffleMLP(config) self.post_init() def get_input_embeddings(self): """ This function is used to fetch the first embedding layer to activate grads on inputs. """ return self.patch_embedding def forward( self, pixel_values: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: r""" Example: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, MllamaVisionModel >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" >>> model = MllamaVisionModel.from_pretrained(checkpoint) >>> processor = AutoProcessor.from_pretrained(checkpoint) >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> output = model(**inputs) >>> print(output.last_hidden_state.shape) torch.Size([1, 1, 4, 1025, 7680]) ``` """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # num_concurrent_media and num_chunks are both currently 1 batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape num_concurrent_media = 1 num_chunks = 1 hidden_state = self.patch_embedding(pixel_values) _, num_patches, hidden_dim = hidden_state.shape # Add cls token hidden_state = hidden_state.reshape( batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim ) class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1]) hidden_state = torch.cat([hidden_state, class_embedding], dim=1) num_patches += 1 # Position embeddings hidden_state = hidden_state.reshape( batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim ) positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device) hidden_state = hidden_state + positional_embedding hidden_state = self.layernorm_pre(hidden_state) hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim) freqs_ci = self.rotary_embedding(pixel_values) output = self.model( hidden_state, attention_mask=None, output_hidden_states=output_hidden_states, output_attentions=output_attentions, freqs_ci=freqs_ci, ) hidden_state = output.last_hidden_state hidden_state = self.layernorm_post(hidden_state) hidden_state = hidden_state[:, :-1, :] # now, we use Llama4VisionPixelShuffle + mlp to project embeddings hidden_state = self.vision_adapter(hidden_state) hidden_states = output.hidden_states if output_hidden_states else None if output_attentions: attentions = output[2] else: attentions = None if not return_dict: return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_state, hidden_states=hidden_states, attentions=attentions, ) class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): _tp_plan = {} base_model_prefix = "" config_class = Llama4Config _supports_flex_attn = True def __init__(self, config: Llama4Config): super().__init__(config) self.vision_model = Llama4VisionModel(config.vision_config) self.multi_modal_projector = Llama4MultiModalProjector(config) self.language_model = Llama4ForCausalLM(config.text_config) self.vocab_size = config.text_config.vocab_size self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def get_output_embeddings(self): return self.language_model.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def set_decoder(self, decoder): self.language_model.set_decoder(decoder) def get_decoder(self): return self.language_model.get_decoder() def get_image_features( self, pixel_values: torch.FloatTensor, vision_feature_layer: Union[int, List[int]], vision_feature_select_strategy: str, **kwargs, ): """ Obtains image last hidden states from the vision tower and apply al projection. Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) The tensors corresponding to the input images. vision_feature_layer (`Union[int, List[int]]`): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. vision_feature_select_strategy (`str`): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ if vision_feature_select_strategy not in ["default", "full"]: raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}") kwargs = {k: v for k, v in kwargs.items() if v is not None} image_outputs = self.vision_model(pixel_values, output_hidden_states=False, **kwargs) hidden_state = image_outputs.last_hidden_state return hidden_state @replace_return_docstrings(output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, image_sizes: torch.Tensor = None, **lm_kwargs, ) -> Union[Tuple, Llama4CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, LlavaForConditionalGeneration >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, text=prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(**inputs, max_new_tokens=15) >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_config.vision_feature_select_strategy ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if pixel_values is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_features = self.get_image_features( pixel_values=pixel_values, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, image_sizes=image_sizes, ) original_inputs_embeds_shape = inputs_embeds.shape vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.multi_modal_projector(vision_flat) special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) final_mask = special_image_mask.to(inputs_embeds.device) inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) final_mask_1d = final_mask[..., 0].reshape(-1) num_tokens_to_fill = final_mask_1d.sum() if num_tokens_to_fill != projected_vision_flat.size(0): raise ValueError( f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " f"but multi_modal_projector returned {projected_vision_flat.size(0)}" ) expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat) inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, **lm_kwargs, ) logits = outputs[0] loss = None if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: # we use the input attention mask to shift the logits and labels, because it is 2D. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return Llama4CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, cache_position=None, logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values return model_inputs @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask __all__ = [ "Llama4PreTrainedModel", "Llama4TextModel", "Llama4VisionModel", "Llama4ForCausalLM", "Llama4ForConditionalGeneration", ]
Memory