Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
e0ae107
initial implementation for mxfp8
cyanguwa Jan 31, 2026
23434b5
semi-working FP8; broken F16
cyanguwa Feb 4, 2026
dbb68b8
clean up last commit
cyanguwa Feb 4, 2026
c627231
comment out F16 pass
cyanguwa Feb 4, 2026
d27a267
Merge branch 'NVIDIA:main' into mxfp8_fwd
cyanguwa Feb 6, 2026
3f3b9e6
pull in grouped_quantize for MXFP8
cyanguwa Feb 6, 2026
850b16e
grouped tensor - pytorch
cyanguwa Feb 7, 2026
46f2eb1
quantize mxfp8
cyanguwa Feb 7, 2026
e86207c
fix shapes/strides
cyanguwa Feb 10, 2026
4e854d5
fix unfused; clean up
cyanguwa Feb 12, 2026
cd06398
split d to d_qk/d_v; attempt at bwd
cyanguwa Feb 13, 2026
d2a63a1
merge main
cyanguwa Feb 13, 2026
730a472
fix last merge
cyanguwa Feb 14, 2026
d9ff566
update FE
cyanguwa Feb 14, 2026
2b264d7
attempt at SWA/MLA
cyanguwa Feb 14, 2026
2008bed
remove prints
cyanguwa Feb 14, 2026
239f58a
remove leftover prints
cyanguwa Feb 14, 2026
f44a775
Revert "update FE"
cyanguwa Feb 14, 2026
965572b
update FE
cyanguwa Feb 14, 2026
91025c7
fix MLA O strides; add bottom_right_diagonal
cyanguwa Feb 17, 2026
d655e7e
attempt at bwd
cyanguwa Feb 18, 2026
a4ab691
fix get_quantizers; attempt at bwd
cyanguwa Feb 19, 2026
a85070d
fix fprop; add o_format
cyanguwa Feb 20, 2026
8909b35
attempt at bwd with o_format/d_out_format/dqkv_layout
cyanguwa Feb 20, 2026
90a636c
fix dtype/o_format/etc in bwd calls
cyanguwa Feb 21, 2026
8c72dea
fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8
cyanguwa Feb 21, 2026
5f23edd
fix upon last commit for paddedsizes
cyanguwa Feb 21, 2026
18c5580
add mxfp8 env var
cyanguwa Feb 21, 2026
6847645
disable FA for mxfp8
cyanguwa Feb 21, 2026
c5a98d5
add mha test
cyanguwa Feb 21, 2026
7e61ecd
attempt at bwd; force determinism; fix shapes
cyanguwa Feb 24, 2026
6d468da
remove prints
cyanguwa Feb 26, 2026
9f8e856
update FE
cyanguwa Feb 26, 2026
facef79
update FE from pre-merge branch to post-merge develop
cyanguwa Feb 26, 2026
fd33cca
allow MXFP8 linear + f16 attn
cyanguwa Feb 26, 2026
5079d55
test cp a2a
cyanguwa Feb 27, 2026
06b7d49
remove prints temporarily
cyanguwa Feb 27, 2026
7fbe399
test cp p2p
cyanguwa Feb 27, 2026
aa05a2a
minor fixes for mla
cyanguwa Feb 28, 2026
00e6693
open up a2a for mla
cyanguwa Feb 28, 2026
b8d28ce
test ag
cyanguwa Feb 28, 2026
d6ecadc
tweaks for last commit
cyanguwa Feb 28, 2026
3ac48cd
enable mla ag
cyanguwa Mar 1, 2026
169ae8a
merge main
cyanguwa Mar 1, 2026
5d4fa5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2026
81c18fa
fix merge
cyanguwa Mar 1, 2026
1f14f2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2026
ccebe77
fix merge
cyanguwa Mar 1, 2026
c52c5f4
revert to main grouped tensor impl
cyanguwa Mar 1, 2026
5b776ec
minor tweaks to return to main
cyanguwa Mar 1, 2026
4eee2bc
remove prints
cyanguwa Mar 3, 2026
8500121
fix combine_and_quantize for f16
cyanguwa Mar 3, 2026
0c2c466
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
6744aee
minor tweaks
cyanguwa Mar 3, 2026
4cec878
tweak tests
cyanguwa Mar 3, 2026
5c8e939
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
7b6b364
fix ds descale_o
cyanguwa Mar 3, 2026
462eb4f
Revert "fix ds descale_o"
cyanguwa Mar 3, 2026
77995d2
minor fixes for p2p and ag
cyanguwa Mar 7, 2026
586b698
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
url = https://github.com/google/googletest.git
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git
branch = develop
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated from 8d19d3 to b4370f
57 changes: 45 additions & 12 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
DotProductAttention,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
)
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
MXFP8BlockScaling,
Format,
)
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
from utils import ModelConfig, compare_and_assert

dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
Expand Down Expand Up @@ -180,19 +186,23 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
deterministic="False",
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
logging.root.setLevel(log_level)
# When is_training is False, gradient outputs are None.
is_training = is_training == "True"

if deterministic == "True":
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
else:
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
# set up environment variables and config
fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
fp8_mha = fp8_mha == "True" and dtype == "fp8"
f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True"
os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
Expand Down Expand Up @@ -247,6 +257,10 @@ def run_dpa_with_cp(
fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
if scaling_mode == "current":
fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
if scaling_mode == "mxfp8":
fp8_recipe = MXFP8BlockScaling(
fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha
)

# instantiate attention module
core_attn = DotProductAttention(
Expand Down Expand Up @@ -302,10 +316,25 @@ def run_dpa_with_cp(
fp8_dtype=tex.DType.kFloat8E5M2,
device="cuda",
)
if scaling_mode == "mxfp8":
qkv_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
)
qkv_quantizer.optimize_for_gemm = True
qkv_quantizer.internal = False
dout_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
rowwise=True,
columnwise=True,
)
dout_quantizer.optimize_for_gemm = True
dout_quantizer.internal = False
qkv_layout = "_".join([qkv_format] * 3)
q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
if fp8_mha:
q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
if fp8_mha and scaling_mode != "mxfp8":
q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
for x in [q, k, v]:
x.requires_grad = True

Expand Down Expand Up @@ -351,12 +380,12 @@ def run_dpa_with_cp(
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
# fp8_output=fp8_mha,
)
if config.return_max_logit:
out, max_logit = out
if is_training:
if fp8_bwd and fp8_mha:
if fp8_bwd and fp8_mha and scaling_mode != "mxfp8":
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
else:
Expand Down Expand Up @@ -412,8 +441,8 @@ def run_dpa_with_cp(
qkv_quantizer.amax.fill_(0.0)
dout_quantizer.scale.fill_(1.0)
dout_quantizer.amax.fill_(0.0)
if fp8_mha:
q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
if fp8_mha and scaling_mode != "mxfp8":
q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
if is_training:
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
Expand Down Expand Up @@ -468,12 +497,12 @@ def run_dpa_with_cp(
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
# fp8_output=fp8_mha,
)
if config.return_max_logit:
out_, max_logit_ = out_
if is_training:
if fp8_bwd and fp8_mha:
if fp8_bwd and fp8_mha and scaling_mode != "mxfp8":
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
else:
Expand Down Expand Up @@ -502,9 +531,13 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
for i, tensor in enumerate(tensors):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
print(
f"========= {torch.cuda.current_device()}: tensors[{i}].shape:"
f" {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}"
)
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors
Expand Down
64 changes: 45 additions & 19 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,21 +1804,30 @@ def get_model(dtype, config):

model_configs_fp8_vs_f16 = {
# test: ModelConfig(b, sq, hq, dqk)
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
"fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
"fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
"fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
"fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
"fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
"fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
"fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
"fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
"fp8_9": ModelConfig(
2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)
),
"fp8_10": ModelConfig(
2,
4096,
128,
192,
head_dim_v=128,
attn_mask_type="causal",
),
# "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
# "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
# "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
# "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
# "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
# "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
# "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
# "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
# "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
# "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
param_types_fp8_vs_f16 = [torch.bfloat16] # [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]

Expand All @@ -1832,7 +1841,7 @@ def get_model(dtype, config):
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"])
def test_mha_fp8_vs_f16(
dtype,
model,
Expand Down Expand Up @@ -1863,6 +1872,12 @@ def test_mha_fp8_vs_f16(
fp8_dpa=True,
fp8_mha=True,
)
elif scaling_mode == "mxfp8":
fp8_recipe = recipe.MXFP8BlockScaling(
fp8_format=recipe.Format.E4M3,
fp8_dpa=True,
fp8_mha=True,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
available_backends, _, _ = get_available_attention_backends(
Expand Down Expand Up @@ -2046,7 +2061,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)

with autocast(enabled=fp8_mha, recipe=fp8_recipe):
out = mha(
hidden_states,
Expand Down Expand Up @@ -2082,7 +2096,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
"""Test DotProductAttention module in FP8"""
config = model_configs_fp8_vs_f16[model]
Expand Down Expand Up @@ -2114,6 +2128,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
fp8_format=recipe.Format.HYBRID,
fp8_dpa=True,
)
elif scaling_mode == "mxfp8":
fp8_recipe = recipe.MXFP8BlockScaling(
fp8_format=recipe.Format.E4M3,
fp8_dpa=True,
fp8_mha=False,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
available_backends, _, _ = get_available_attention_backends(
Expand Down Expand Up @@ -2274,7 +2294,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
with quantized_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim_qk,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
Expand Down Expand Up @@ -2319,7 +2339,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim_qk,
"dqk": config.head_dim_qk,
"dv": config.head_dim_v,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
Expand All @@ -2335,6 +2356,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
if i == 2:
layout = layout.replace("d", "dv")
else:
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
if config.dropout_p == 0.0:
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
Expand All @@ -2359,6 +2384,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:

qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
Expand All @@ -2369,14 +2395,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
inp[1],
inp[2],
qkv_format=qkv_format,
window_size=config.window_size,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
fp8_output=fp8_dpa,
)
if is_training:
out.backward(out_grad)
Expand Down
Loading