diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 0ce947c30c..5ce7d5e936 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -338,6 +338,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None): self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float( metrics["scalar"].get("evaluation/moe_lb_loss", 0.0) ) + self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] += float( + metrics["scalar"].get("evaluation/indexer_loss", 0.0) + ) self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(metrics["scalar"].get("evaluation/mtp_loss", 0.0)) self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float( metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0) @@ -355,6 +358,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None): self.cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = ( self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count ) + self.cumulative_eval_metrics["scalar"]["eval/avg_indexer_loss"] = ( + self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] / eval_step_count + ) self.cumulative_eval_metrics["scalar"]["eval/avg_mtp_loss"] = ( self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] / eval_step_count ) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 0635851cce..4b1b5f94d8 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -359,6 +359,13 @@ use_sparse_indexer: False index_head_dim: 128 index_n_heads: 64 index_topk: 2048 +# Determines the token selection strategy for indexer loss: +# - False: Uses all tokens (Dense Warm-up). +# - True: Uses only top-k tokens (Sparse Training). +# Note: This is only active when `indexer_loss_scaling_factor` > 0. +sparse_indexer_loss: False +# Multiplier for the indexer KL divergence loss +indexer_loss_scaling_factor: 0.0 # MLA parameters q_lora_rank: 0 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index f698f65477..15d8293cd5 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -534,6 +534,8 @@ class AttentionIndexer(BaseModel): index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.") index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.") index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.") + sparse_indexer_loss: bool = Field(False, description="Determines the token selection strategy for indexer loss.") + indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.") class Llama4Attention(BaseModel): diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index 58fe48de92..e0d6e4e9f1 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -73,6 +73,7 @@ from maxtext.inference import paged_attention from maxtext.inference.kvcache import KVQuant from maxtext.utils.sharding import create_sharding +from maxtext.utils.globals import EPS class Indexer(nnx.Module): @@ -246,10 +247,10 @@ def __call__( the inputs and configuration. Returns: - index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens + indexer_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens and large negative values otherwise. topk_indices: Indices of the top-k selected tokens [b, t, k]. - index_score: The computed relevance scores [b, t, s]. + indexer_score: The computed relevance scores [b, t, s]. Notation: b: Batch size @@ -283,27 +284,27 @@ def __call__( logits = jax.nn.relu(logits) # Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h] weights = self.weights_proj(inputs_q) - # Weights scaling affect index_score, but does not affect topk_indices. Keep scaling for numerical stability. + # Weights scaling affect indexer_score, but does not affect topk_indices. Keep scaling for numerical stability. # https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480 weights = weights * (self.n_heads**-0.5) * self.softmax_scale # Aggregate head-wise logits: logits @ weights - index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s] + indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s] # Apply attention mask before TopK if attention_mask is not None: - index_score += attention_mask + indexer_score += attention_mask # TopK selection based on index score - _, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k] + _, topk_indices = jax.lax.top_k(indexer_score, k=self.index_topk) # topk_indices [b, t, k] # Create Sparse Index Mask: 0 and large negatives - index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s] + indexer_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s] # Re-apply attention mask after TopK: in case number of unmasked tokens < TopK if attention_mask is not None: - index_mask += attention_mask + indexer_mask += attention_mask - return index_mask, topk_indices, index_score + return indexer_mask, topk_indices, indexer_score def mla_as_linen( @@ -951,6 +952,71 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm return key, value, cached_values + def calculate_indexer_loss( + self, + indexer_score: Array, + query: Array, + key: Array, + attention_mask: Optional[Array | None], + indexer_mask: Array, + sparse_loss: bool, + scaling_factor: float, + ) -> Array: + """Calculates the indexer KL divergence loss. + + This loss trains the indexer to predict which tokens are important by matching + the distribution of true attention scores from the main model. + + The target distribution is derived through the following steps: + 1. Compute raw attention scores via Q @ K^T. + 2. Aggregate scores by summing across all attention heads. + 3. Apply L1-normalization across the sequence dimension. + + target_distribution = L1_Normalize(Sum_h(Softmax(Q @ K^T))) + + Reference: + DeepSeek-V3.2 - https://arxiv.org/pdf/2512.02556 + + Args: + indexer_score: Scores predicted by indexer [batch, q_len, kv_len]. + query: Query tensor from main model [batch, q_len, heads, dim]. + key: Key tensor from main model [batch, kv_len, heads, dim]. + attention_mask: Attention mask [batch, q_len, kv_len] or None. + indexer_mask: Indexer mask [batch, q_len, kv_len]. + sparse_loss: Whether to use sparse loss. + scaling_factor: The scaling factor for the loss. + + Returns: + The computed KL divergence loss. + """ + # Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s] + attention_scores = jnp.einsum("bthd, bshd -> bhts", query, key, precision=self.config.matmul_precision) + + if sparse_loss: + # indexer_mask is already pre-filtered with the attention_mask if any + attention_scores = attention_scores + indexer_mask[:, None, :, :] + indexer_score = indexer_score + indexer_mask + elif attention_mask is not None: + # indexer_score already applies attention_mask; updating attention_scores only + attention_scores = attention_scores + attention_mask[:, None, :, :] + + # Use float32 for softmax numerical stability. + attention_probs = jax.nn.softmax(attention_scores.astype(jnp.float32), axis=-1) + indexer_probs = jax.nn.softmax(indexer_score.astype(jnp.float32), axis=-1) + + # Aggregate heads: [b, h, t, s] -> [b, t, s] + attention_probs = jnp.sum(attention_probs, axis=1) + # L1 normalize aggregated target distribution + attention_probs = attention_probs / (jnp.sum(attention_probs, axis=-1, keepdims=True) + EPS) + + # KL Divergence: KL(attention || indexer) + log_attention_probs = jnp.log(attention_probs + EPS) + log_indexer_probs = jnp.log(indexer_probs + EPS) + kl_per_token = attention_probs * (log_attention_probs - log_indexer_probs) + indexer_loss = jnp.mean(jnp.sum(kl_per_token, axis=-1)) + + return indexer_loss * scaling_factor + def __call__( self, inputs_q: Array, @@ -1013,16 +1079,18 @@ def __call__( value = checkpoint_name(value, "value_proj") # Indexer Logic - index_mask = None + indexer_mask = None if self.use_sparse_indexer: if model_mode != MODEL_MODE_TRAIN: raise NotImplementedError("Sparse indexer has not implemented for inference yet.") # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len] attention_mask = self.attention_op.generate_attention_mask( query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask - ).squeeze(axis=(1, 2)) - # apply indexer, index_mask [b, q_len, kv_len] - index_mask, _, _ = self.indexer( + ) + if attention_mask is not None: + attention_mask = attention_mask.squeeze(axis=(1, 2)) + # apply indexer, indexer_mask [b, q_len, kv_len] + indexer_mask, _, indexer_score = self.indexer( inputs_q=inputs_q, low_rank_q=low_rank_q, inputs_kv=inputs_kv, @@ -1030,6 +1098,18 @@ def __call__( attention_mask=attention_mask, ) + if self.config.indexer_loss_scaling_factor > 0.0: + indexer_loss = self.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=self.config.sparse_indexer_loss, + scaling_factor=self.config.indexer_loss_scaling_factor, + ) + self.sow(nnx.Intermediate, "indexer_loss", indexer_loss) + # Check if we need QK Clip stats use_qk_clip = self.model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip @@ -1047,7 +1127,7 @@ def __call__( decoder_segment_ids, model_mode, cached_values, - index_mask=index_mask, + indexer_mask=indexer_mask, record_max_logits=use_qk_clip, ) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index e9ee57cbae..b2aed7b27e 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -884,7 +884,7 @@ def apply_attention( previous_chunk: Any = None, bidirectional_mask: Any = None, sinks: Array | None = None, - index_mask: Array | None = None, + indexer_mask: Array | None = None, record_max_logits: bool = False, *, qk_product_einsum: Callable[..., Array], @@ -929,7 +929,7 @@ def apply_attention( previous_chunk, bidirectional_mask=bidirectional_mask, sinks=sinks, - index_mask=index_mask, + indexer_mask=indexer_mask, record_max_logits=record_max_logits, qk_product_einsum=qk_product_einsum, wv_product_einsum=wv_product_einsum, @@ -1134,7 +1134,7 @@ def tpu_flash_attention( decoder_segment_ids: Array | None, attn_logits_soft_cap: float | None = None, sinks: Array | None = None, - index_mask: Array | None = None, + indexer_mask: Array | None = None, record_max_logits: bool = False, ) -> tuple[Array, Array]: """TPU Flash Attention.""" @@ -1161,12 +1161,12 @@ def tpu_flash_attention( axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep) axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep) - index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH)) + indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH)) else: axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) - index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) + indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel @@ -1376,7 +1376,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): None, # no sharding for cp_size None, # no sharding for load_balanced_context_parallel sink_axis_names, # sharding align with query heads - index_mask_axis_names, + indexer_mask_axis_names, ), out_specs=out_specs, check_vma=False, @@ -1392,7 +1392,7 @@ def wrap_flash_attention( cp_size, load_balanced_context_parallel, sinks, - index_mask, + indexer_mask, ): # If load_balanced_context_parallel is enabled, reorder the key and value tensors # to ensure that they are contiguous in memory. @@ -1421,11 +1421,11 @@ def wrap_flash_attention( decoder_segment_ids_tuple = None if self.config.use_tokamax_splash: - if self.config.use_sparse_indexer and index_mask is not None: + if self.config.use_sparse_indexer and indexer_mask is not None: # Construct the splash kernel call with dynamic mask - def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask): + def dynamic_mask_splash_kernel(q, k, v, segment, sinks, indexer_mask): splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha( - mask=index_mask, + mask=indexer_mask, config=sa_config, ) kernel = partial(splash_kernel, max_logit_value=max_logit_value) @@ -1438,13 +1438,13 @@ def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask): # Iterate over batch dimension for (query, key, value, segment, sinks, mask) attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0)) - index_mask = jnp.isclose(index_mask, 0.0) + indexer_mask = jnp.isclose(indexer_mask, 0.0) if record_max_logits: - attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask) + attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask) return attention_output, max_logits else: - attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask) + attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask) return attention_output, None else: kernel = partial(splash_kernel, max_logit_value=max_logit_value) @@ -1509,7 +1509,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q) decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv) sinks = _maybe_shard_with_pspec(sinks, sink_axis_names) - index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names) + indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names) ret = wrap_flash_attention( query, @@ -1522,7 +1522,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): cp_size, load_balanced_context_parallel, sinks, - index_mask, + indexer_mask, ) x, max_logits = ret @@ -1766,7 +1766,7 @@ def apply_attention_dot( previous_chunk: Any = None, bidirectional_mask: Any = None, sinks: Array | None = None, - index_mask: Array | None = None, + indexer_mask: Array | None = None, record_max_logits: bool = False, *, qk_product_einsum: Callable[..., Array], @@ -1846,11 +1846,11 @@ def apply_attention_dot( # Apply index mask, deepseek sparse attention # index mask contains 0.0 for kept tokens and large negative for masked tokens. - if index_mask is not None: - # index_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len] - index_mask = index_mask[:, None, None, :, :] + if indexer_mask is not None: + # indexer_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len] + indexer_mask = indexer_mask[:, None, None, :, :] # attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len] - attn_weights = apply_mask_to_logits(attn_weights, index_mask) + attn_weights = apply_mask_to_logits(attn_weights, indexer_mask) if self.is_partition_in_decode(q_seq_len): attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None)) @@ -2035,7 +2035,7 @@ def __call__( previous_chunk=None, bidirectional_mask=None, sinks=None, - index_mask: Optional[Array] = None, + indexer_mask: Optional[Array] = None, slot: Optional[int] = None, page_state: Optional[page_manager.PageState] = None, record_max_logits: bool = False, @@ -2059,7 +2059,7 @@ def __call__( previous_chunk=previous_chunk, bidirectional_mask=bidirectional_mask, sinks=sinks, - index_mask=index_mask, + indexer_mask=indexer_mask, record_max_logits=record_max_logits, qk_product_einsum=self.AqtEinsum_0, wv_product_einsum=self.AqtEinsum_1, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 2cdaff130f..4b3505b224 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -218,6 +218,24 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): mtp_loss = calculate_mtp_loss(intermediate_outputs, config) loss += mtp_loss + # get indexer loss + indexer_loss = 0.0 + if config.use_sparse_indexer and config.indexer_loss_scaling_factor > 0.0: + indexer_losses = [] + # Extract 'indexer_loss' from model intermediates. + # We check for paths ending in ('self_attention', 'indexer_loss'). + # This handles varying paths caused by different layer names. + for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs): + path_keys = tuple(k.key for k in path if hasattr(k, "key")) + if path_keys[-2:] == ("self_attention", "indexer_loss"): + indexer_losses.append(jnp.ravel(val)) + + if indexer_losses: + indexer_loss = jnp.mean(jnp.concatenate(indexer_losses)) + loss += indexer_loss + else: + max_logging.debug("No indexer loss found.") + # get MoE load balance loss moe_lb_loss = 0.0 if config.num_experts > 1: @@ -257,6 +275,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): "z_loss": total_z_loss, "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, + "indexer_loss": indexer_loss, "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, } @@ -327,6 +346,7 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] + indexer_loss = aux["indexer_loss"] z_loss = aux["z_loss"] moe_bias_updates = aux["moe_bias_updates"] mtp_loss = aux["mtp_loss"] @@ -373,6 +393,7 @@ def move(path, value): "learning/loss": loss, "learning/z_loss": z_loss, "learning/moe_lb_loss": moe_lb_loss, + "learning/indexer_loss": indexer_loss, "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } @@ -425,6 +446,7 @@ def eval_step(model, config, state, data, dropout_rng): z_loss = aux["z_loss"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] + indexer_loss = aux["indexer_loss"] mtp_loss = aux["mtp_loss"] metrics = { "scalar": { @@ -433,6 +455,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/total_loss": total_loss, "evaluation/total_weights": total_weights, "evaluation/moe_lb_loss": moe_lb_loss, + "evaluation/indexer_loss": indexer_loss, "evaluation/mtp_loss": mtp_loss, "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 647d162041..e4cad14906 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -94,6 +94,7 @@ def accumulate_gradient(acc_grad_and_loss, data): (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["total_loss"] acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] + acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"] acc_grad_and_loss["grad"] = jax.tree_util.tree_map(lambda x, y: x + y, cur_batch_gradient, acc_grad_and_loss["grad"]) acc_grad_and_loss["total_weights"] += aux["total_weights"] @@ -114,6 +115,7 @@ def reshape_to_microbatch_accumulations(batch_arr): "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0, + "indexer_loss": 0.0, "mtp_loss": 0.0, "ga_params": ga_params, } @@ -124,6 +126,7 @@ def reshape_to_microbatch_accumulations(batch_arr): loss = ( grad_and_loss["loss"] / grad_and_loss["total_weights"] + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps + + grad_and_loss["indexer_loss"] / config.gradient_accumulation_steps + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 6fc3177f7c..fc2c3c2d24 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -34,6 +34,7 @@ MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, + DEFAULT_MASK_VALUE, ) from maxtext.layers.attention_mla import MLA from maxtext.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask @@ -1591,6 +1592,107 @@ def test_tpu_flash_attention_context_parallel( f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", ) + def get_indexer_test_data(self, batch_size, q_len, kv_len, num_heads, head_dim): + """Helper to generate random data for indexer tests.""" + key_q, key_k, key_is = jax.random.split(self.rng, 3) + query = jax.random.normal(key_q, (batch_size, q_len, num_heads, head_dim)) + key = jax.random.normal(key_k, (batch_size, kv_len, num_heads, head_dim)) + indexer_score = jax.random.normal(key_is, (batch_size, q_len, kv_len)) + return query, key, indexer_score + + def get_causal_mask_for_indexer(self, batch_size, q_len, kv_len): + """Helper to generate a causal mask with DEFAULT_MASK_VALUE.""" + row_ids = jnp.arange(q_len)[:, None] + col_ids = jnp.arange(kv_len)[None, :] + attention_mask = jnp.where(col_ids <= row_ids, 0.0, DEFAULT_MASK_VALUE) + attention_mask = jnp.broadcast_to(attention_mask, (batch_size, q_len, kv_len)) + return attention_mask + + def test_indexer_loss(self): + """Test indexer loss computation.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_sparse_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + batch_size = 2 + q_len = 3 + kv_len = 4 + num_heads = 5 + head_dim = 6 + scaling_factor = 0.5 + + query, key, indexer_score = self.get_indexer_test_data(batch_size, q_len, kv_len, num_heads, head_dim) + + # Causal mask + attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) + indexer_score += attention_mask + + topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) + indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask + + loss_dense = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=False, + scaling_factor=scaling_factor, + ) + + loss_sparse = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=True, + scaling_factor=scaling_factor, + ) + + np.testing.assert_array_less(0.0, loss_dense) + np.testing.assert_array_less(0.0, loss_sparse) + + def test_indexer_loss_kl_divergence_zero(self): + """Test that KL divergence is 0 when target and pred distributions match exactly.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_sparse_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + batch_size = 2 + q_len = 3 + kv_len = 4 + num_heads = 5 + head_dim = 6 + + # Setup perfectly matching distributions + # Make query and key such that einsum yields zeros (so softmax gives uniform distribution over unmasked) + query = jnp.zeros((batch_size, q_len, num_heads, head_dim)) + key = jnp.zeros((batch_size, kv_len, num_heads, head_dim)) + + # Causal mask + attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) + + # Indexer score matches the shape and is uniform + indexer_score = jnp.zeros((batch_size, q_len, kv_len)) + attention_mask + + topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) + indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask + + loss = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=False, + scaling_factor=1.0, + ) + + np.testing.assert_allclose(loss, 0.0, atol=1e-5) + class Qwen3NextGatedDeltaNetTest(unittest.TestCase): """Test for the Gated Delta Net in Qwen3-Next"""