Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(
metrics["scalar"].get("evaluation/moe_lb_loss", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] += float(
metrics["scalar"].get("evaluation/indexer_loss", 0.0)
)
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(metrics["scalar"].get("evaluation/mtp_loss", 0.0))
self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float(
metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)
Expand All @@ -355,6 +358,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
self.cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
)
self.cumulative_eval_metrics["scalar"]["eval/avg_indexer_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] / eval_step_count
)
self.cumulative_eval_metrics["scalar"]["eval/avg_mtp_loss"] = (
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] / eval_step_count
)
Expand Down
7 changes: 7 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ use_sparse_indexer: False
index_head_dim: 128
index_n_heads: 64
index_topk: 2048
# Determines the token selection strategy for indexer loss:
# - False: Uses all tokens (Dense Warm-up).
# - True: Uses only top-k tokens (Sparse Training).
# Note: This is only active when `indexer_loss_scaling_factor` > 0.
sparse_indexer_loss: False
# Multiplier for the indexer KL divergence loss
indexer_loss_scaling_factor: 0.0

# MLA parameters
q_lora_rank: 0
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ class AttentionIndexer(BaseModel):
index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
sparse_indexer_loss: bool = Field(False, description="Determines the token selection strategy for indexer loss.")
indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")


class Llama4Attention(BaseModel):
Expand Down
108 changes: 94 additions & 14 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from maxtext.inference import paged_attention
from maxtext.inference.kvcache import KVQuant
from maxtext.utils.sharding import create_sharding
from maxtext.utils.globals import EPS


class Indexer(nnx.Module):
Expand Down Expand Up @@ -246,10 +247,10 @@ def __call__(
the inputs and configuration.
Returns:
index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
indexer_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
and large negative values otherwise.
topk_indices: Indices of the top-k selected tokens [b, t, k].
index_score: The computed relevance scores [b, t, s].
indexer_score: The computed relevance scores [b, t, s].
Notation:
b: Batch size
Expand Down Expand Up @@ -283,27 +284,27 @@ def __call__(
logits = jax.nn.relu(logits)
# Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h]
weights = self.weights_proj(inputs_q)
# Weights scaling affect index_score, but does not affect topk_indices. Keep scaling for numerical stability.
# Weights scaling affects indexer_score, but does not affect topk_indices. Keep scaling for numerical stability.
# https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
weights = weights * (self.n_heads**-0.5) * self.softmax_scale
# Aggregate head-wise logits: logits @ weights
index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]
indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]

# Apply attention mask before TopK
if attention_mask is not None:
index_score += attention_mask
indexer_score += attention_mask

# TopK selection based on index score
_, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k]
_, topk_indices = jax.lax.top_k(indexer_score, k=self.index_topk) # topk_indices [b, t, k]

# Create Sparse Index Mask: 0 and large negatives
index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
indexer_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]

# Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
if attention_mask is not None:
index_mask += attention_mask
indexer_mask += attention_mask

return index_mask, topk_indices, index_score
return indexer_mask, topk_indices, indexer_score


def mla_as_linen(
Expand Down Expand Up @@ -951,6 +952,71 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm

return key, value, cached_values

def calculate_indexer_loss(
    self,
    indexer_score: Array,
    query: Array,
    key: Array,
    attention_mask: Array | None,
    indexer_mask: Array,
    sparse_loss: bool,
    scaling_factor: float,
) -> Array:
  """Calculates the indexer KL divergence loss.

  This loss trains the indexer to predict which tokens are important by matching
  the distribution of true attention scores from the main model.

  The target distribution is derived through the following steps:
    1. Compute raw attention scores via Q @ K^T.
    2. Softmax over the key dimension, then sum across all attention heads.
    3. Apply L1-normalization across the sequence dimension.

    target_distribution = L1_Normalize(Sum_h(Softmax(Q @ K^T)))

  NOTE(review): the target distribution is NOT wrapped in stop_gradient here, so
  this loss also backpropagates into the main model's Q/K projections. The
  DeepSeek reference detaches the target — confirm this is intentional.

  Reference:
    DeepSeek-V3.2 - https://arxiv.org/pdf/2512.02556

  Args:
    indexer_score: Scores predicted by indexer [batch, q_len, kv_len].
    query: Query tensor from main model [batch, q_len, heads, dim].
    key: Key tensor from main model [batch, kv_len, heads, dim].
    attention_mask: Additive attention mask [batch, q_len, kv_len] (0 for kept
      tokens, large negative for masked), or None.
    indexer_mask: Additive indexer top-k mask [batch, q_len, kv_len].
    sparse_loss: If True, restrict the loss to the indexer's top-k tokens
      (sparse training); otherwise use all unmasked tokens (dense warm-up).
    scaling_factor: Multiplier applied to the final KL loss.

  Returns:
    The scaled scalar KL divergence loss.
  """
  # Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s]
  attention_scores = jnp.einsum("bthd, bshd -> bhts", query, key, precision=self.config.matmul_precision)

  if sparse_loss:
    # indexer_mask is already pre-filtered with the attention_mask if any,
    # so applying it alone restricts both distributions to the top-k tokens.
    attention_scores = attention_scores + indexer_mask[:, None, :, :]
    indexer_score = indexer_score + indexer_mask
  elif attention_mask is not None:
    # indexer_score already applies attention_mask; updating attention_scores only
    attention_scores = attention_scores + attention_mask[:, None, :, :]

  # Use float32 for softmax numerical stability.
  attention_probs = jax.nn.softmax(attention_scores.astype(jnp.float32), axis=-1)
  indexer_probs = jax.nn.softmax(indexer_score.astype(jnp.float32), axis=-1)

  # Aggregate heads: [b, h, t, s] -> [b, t, s]
  attention_probs = jnp.sum(attention_probs, axis=1)
  # L1 normalize aggregated target distribution (EPS guards fully-masked rows).
  attention_probs = attention_probs / (jnp.sum(attention_probs, axis=-1, keepdims=True) + EPS)

  # KL Divergence: KL(attention || indexer); EPS avoids log(0) on masked entries.
  log_attention_probs = jnp.log(attention_probs + EPS)
  log_indexer_probs = jnp.log(indexer_probs + EPS)
  kl_per_token = attention_probs * (log_attention_probs - log_indexer_probs)
  # Sum over the key dimension (per-token KL), then mean over batch and query positions.
  indexer_loss = jnp.mean(jnp.sum(kl_per_token, axis=-1))

  return indexer_loss * scaling_factor

def __call__(
self,
inputs_q: Array,
Expand Down Expand Up @@ -1013,23 +1079,37 @@ def __call__(
value = checkpoint_name(value, "value_proj")

# Indexer Logic
index_mask = None
indexer_mask = None
if self.use_sparse_indexer:
if model_mode != MODEL_MODE_TRAIN:
raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
# generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
attention_mask = self.attention_op.generate_attention_mask(
query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask
).squeeze(axis=(1, 2))
# apply indexer, index_mask [b, q_len, kv_len]
index_mask, _, _ = self.indexer(
)
if attention_mask is not None:
attention_mask = attention_mask.squeeze(axis=(1, 2))
# apply indexer, indexer_mask [b, q_len, kv_len]
indexer_mask, _, indexer_score = self.indexer(
inputs_q=inputs_q,
low_rank_q=low_rank_q,
inputs_kv=inputs_kv,
inputs_positions=inputs_positions,
attention_mask=attention_mask,
)

if self.config.indexer_loss_scaling_factor > 0.0:
indexer_loss = self.calculate_indexer_loss(
indexer_score=indexer_score,
query=query,
key=key,
attention_mask=attention_mask,
indexer_mask=indexer_mask,
sparse_loss=self.config.sparse_indexer_loss,
scaling_factor=self.config.indexer_loss_scaling_factor,
)
self.sow(nnx.Intermediate, "indexer_loss", indexer_loss)

# Check if we need QK Clip stats
use_qk_clip = self.model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip

Expand All @@ -1047,7 +1127,7 @@ def __call__(
decoder_segment_ids,
model_mode,
cached_values,
index_mask=index_mask,
indexer_mask=indexer_mask,
record_max_logits=use_qk_clip,
)

Expand Down
44 changes: 22 additions & 22 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def apply_attention(
previous_chunk: Any = None,
bidirectional_mask: Any = None,
sinks: Array | None = None,
index_mask: Array | None = None,
indexer_mask: Array | None = None,
record_max_logits: bool = False,
*,
qk_product_einsum: Callable[..., Array],
Expand Down Expand Up @@ -929,7 +929,7 @@ def apply_attention(
previous_chunk,
bidirectional_mask=bidirectional_mask,
sinks=sinks,
index_mask=index_mask,
indexer_mask=indexer_mask,
record_max_logits=record_max_logits,
qk_product_einsum=qk_product_einsum,
wv_product_einsum=wv_product_einsum,
Expand Down Expand Up @@ -1134,7 +1134,7 @@ def tpu_flash_attention(
decoder_segment_ids: Array | None,
attn_logits_soft_cap: float | None = None,
sinks: Array | None = None,
index_mask: Array | None = None,
indexer_mask: Array | None = None,
record_max_logits: bool = False,
) -> tuple[Array, Array]:
"""TPU Flash Attention."""
Expand All @@ -1161,12 +1161,12 @@ def tpu_flash_attention(
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep)
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep)
index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
else:
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))

global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
Expand Down Expand Up @@ -1376,7 +1376,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
None, # no sharding for cp_size
None, # no sharding for load_balanced_context_parallel
sink_axis_names, # sharding align with query heads
index_mask_axis_names,
indexer_mask_axis_names,
),
out_specs=out_specs,
check_vma=False,
Expand All @@ -1392,7 +1392,7 @@ def wrap_flash_attention(
cp_size,
load_balanced_context_parallel,
sinks,
index_mask,
indexer_mask,
):
# If load_balanced_context_parallel is enabled, reorder the key and value tensors
# to ensure that they are contiguous in memory.
Expand Down Expand Up @@ -1421,11 +1421,11 @@ def wrap_flash_attention(
decoder_segment_ids_tuple = None

if self.config.use_tokamax_splash:
if self.config.use_sparse_indexer and index_mask is not None:
if self.config.use_sparse_indexer and indexer_mask is not None:
# Construct the splash kernel call with dynamic mask
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, indexer_mask):
splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
mask=index_mask,
mask=indexer_mask,
config=sa_config,
)
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
Expand All @@ -1438,13 +1438,13 @@ def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):

# Iterate over batch dimension for (query, key, value, segment, sinks, mask)
attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0))
index_mask = jnp.isclose(index_mask, 0.0)
indexer_mask = jnp.isclose(indexer_mask, 0.0)

if record_max_logits:
attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask)
attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask)
return attention_output, max_logits
else:
attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask)
attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask)
return attention_output, None
else:
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names)
indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)

ret = wrap_flash_attention(
query,
Expand All @@ -1522,7 +1522,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
cp_size,
load_balanced_context_parallel,
sinks,
index_mask,
indexer_mask,
)

x, max_logits = ret
Expand Down Expand Up @@ -1766,7 +1766,7 @@ def apply_attention_dot(
previous_chunk: Any = None,
bidirectional_mask: Any = None,
sinks: Array | None = None,
index_mask: Array | None = None,
indexer_mask: Array | None = None,
record_max_logits: bool = False,
*,
qk_product_einsum: Callable[..., Array],
Expand Down Expand Up @@ -1846,11 +1846,11 @@ def apply_attention_dot(

# Apply index mask, deepseek sparse attention
# index mask contains 0.0 for kept tokens and large negative for masked tokens.
if index_mask is not None:
# index_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
index_mask = index_mask[:, None, None, :, :]
if indexer_mask is not None:
# indexer_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
indexer_mask = indexer_mask[:, None, None, :, :]
# attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len]
attn_weights = apply_mask_to_logits(attn_weights, index_mask)
attn_weights = apply_mask_to_logits(attn_weights, indexer_mask)

if self.is_partition_in_decode(q_seq_len):
attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None))
Expand Down Expand Up @@ -2035,7 +2035,7 @@ def __call__(
previous_chunk=None,
bidirectional_mask=None,
sinks=None,
index_mask: Optional[Array] = None,
indexer_mask: Optional[Array] = None,
slot: Optional[int] = None,
page_state: Optional[page_manager.PageState] = None,
record_max_logits: bool = False,
Expand All @@ -2059,7 +2059,7 @@ def __call__(
previous_chunk=previous_chunk,
bidirectional_mask=bidirectional_mask,
sinks=sinks,
index_mask=index_mask,
indexer_mask=indexer_mask,
record_max_logits=record_max_logits,
qk_product_einsum=self.AqtEinsum_0,
wv_product_einsum=self.AqtEinsum_1,
Expand Down
Loading
Loading