diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 8d19d3182b..b4370f5198 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 8d19d3182bfbc304046a15e9236bec9ff31511fc +Subproject commit b4370f5198bd95ee758ebc2c6b76b887914b702d diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..242d6b9e7a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -19,8 +19,14 @@ DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, +) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -180,19 +186,23 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + deterministic="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # When is_training is False, gradient outputs are None. 
is_training = is_training == "True" - + if deterministic == "True": + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" fp8_mha = fp8_mha == "True" and dtype == "fp8" - f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -247,6 +257,10 @@ def run_dpa_with_cp( fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) # instantiate attention module core_attn = DotProductAttention( @@ -302,10 +316,25 @@ def run_dpa_with_cp( fp8_dtype=tex.DType.kFloat8E5M2, device="cuda", ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: - q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + if fp8_mha and scaling_mode != "mxfp8": + q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -351,12 +380,12 @@ 
def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out, max_logit = out if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -412,8 +441,8 @@ def run_dpa_with_cp( qkv_quantizer.amax.fill_(0.0) dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) - if fp8_mha: - q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + if fp8_mha and scaling_mode != "mxfp8": + q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: @@ -468,12 +497,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out_, max_logit_ = out_ if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: @@ -502,9 +531,13 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: + print( + f"========= {torch.cuda.current_device()}: tensors[{i}].shape:" + f" {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}" + ) assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 31c7041897..7ae73a753a 100644 --- 
a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,21 +1804,30 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fp8_9": ModelConfig( + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), + "fp8_10": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal", + ), + # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + # "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + # "fp8_19": 
ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.bfloat16] # [torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] @@ -1832,7 +1841,7 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_mha_fp8_vs_f16( dtype, model, @@ -1863,6 +1872,12 @@ def test_mha_fp8_vs_f16( fp8_dpa=True, fp8_mha=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=True, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2046,7 +2061,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, @@ -2082,7 +2096,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = 
model_configs_fp8_vs_f16[model] @@ -2114,6 +2128,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8_format=recipe.Format.HYBRID, fp8_dpa=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2274,7 +2294,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -2319,7 +2339,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim_qk, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -2335,6 +2356,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] if config.dropout_p == 0.0: tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") @@ -2359,6 +2384,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") @@ -2369,6 +2395,7 @@ def get_dummy_cuda_rng_tracker() -> 
CudaRNGStatesTracker: inp[1], inp[2], qkv_format=qkv_format, + window_size=config.window_size, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q, @@ -2376,7 +2403,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index ecd0090a3b..116d4dcc41 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -17,6 +17,8 @@ from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils @@ -26,6 +28,12 @@ pytest_logging_level = logging.getLevelName(logging.root.level) +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) @@ -94,25 +102,34 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config.context_parallel = True config.cp_comm_type = cp_comm_type - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) 
does not support THD format" - " yet!" - ) - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if ( + config.window_size != (-1, 0) + and config.window_size != (-1, -1) + and cp_comm_type + in [ + "p2p", + "a2a+p2p", + ] + ): + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" 
) + + # FlashAttention / CP implementation specific: MLA only with KV P2P if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} @@ -151,8 +168,12 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_0": ModelConfig( + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), # GQA + "cp_2_1": ModelConfig( + 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" + ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -190,7 +211,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -215,21 +236,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: configs = [ - "cp_1_0", - "cp_1_1", - "cp_1_4", - "cp_1_5", + # "cp_1_0", + # "cp_1_1", + # "cp_1_4", + # "cp_1_5", "cp_2_0", - "cp_2_2", - "cp_2_3", - "cp_2_4", - "cp_3_2", - "cp_3_4", - "cp_4_2", + "cp_2_1", + # "cp_2_2", + # "cp_2_3", + # "cp_2_4", + # "cp_3_1", + # "cp_3_2", + # "cp_3_4", + # "cp_4_2", ] 
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} - dtypes = ["bf16", "fp8"] - qkv_formats = ["sbhd", "thd"] + dtypes = ["fp8"] # ["bf16", "fp8"] + qkv_formats = ["bshd"] # , "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -241,96 +264,81 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) -@pytest.mark.parametrize("f16_O", [True, False]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) +@pytest.mark.parametrize("f16_O", [True]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): + config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") - if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+!") - if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") - if dtype == "fp8" and get_device_compute_capability() < (9, 0): - pytest.skip("FP8 attention is only supported on sm90+!") + if get_device_compute_capability() < (9, 0) and qkv_format == "thd": + pytest.skip("Only sm90+ architectures support THD format!") + if get_device_compute_capability() < (9, 0) and dtype == "fp8": + pytest.skip("Only sm90+ architectures support FP8 attention!") + + if dtype == 
"fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("dtype=fp8 requires fp8_dpa=True or fp8_mha=True!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: - pytest.skip("Only fp8 works with fp8_bwd=True!") - - config = model_configs_fused_attn[model] - config.context_parallel = True - config.cp_comm_type = cp_comm_type + pytest.skip("fp8_bwd=True requires dtype=fp8!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("dtype!=fp8 requires fp8_dpa=False and fp8_mha=False!") - if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if dtype == "fp8" and cp_comm_type == "all_gather": - pytest.skip( - "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" 
- ) if dtype == "fp8" and qkv_format == "thd": - pytest.skip("FP8 attention cannot work with THD format yet!") + pytest.skip("No support for FP8 attention with THD format!") if dtype == "fp8" and config.attn_bias_type != "no_bias": - pytest.skip("FP8 attention cannot work with bias yet!") - if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("FP8 attention cannot work with sliding window yet!") - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): - pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
- ) - if dtype != "fp8" and (fp8_mha or fp8_dpa): - pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") - if dtype == "fp8" and not (fp8_mha or fp8_dpa): - pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") - if dtype != "fp8" and scaling_mode is not None: - pytest.skip("Only fp8 works with scaling_mode != None!") - if dtype == "fp8" and scaling_mode is None: - pytest.skip("fp8 only works with scaling_mode != None!") - if ( - dtype == "fp8" - and scaling_mode == "current" - and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + pytest.skip("No support for FP8 attention with bias!") + + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if (config.window_size != (-1, 0) and config.window_size != (-1, -1)) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 ): - pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode != "current"): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently only support KV P2P!") - if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently does not support FP8 attention!") - if dtype == "fp8" and config.softmax_type != "vanilla": - pytest.skip("CP implementation does not support non-vanilla 
softmax types in FP8!") - if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip( - "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + if config.softmax_type != "vanilla" and dtype == "fp8": + pytest.skip("No support for non-vanilla softmax with FP8 attention!") + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") if ( - get_cudnn_version() < (9, 18, 0) - and config.softmax_type != "vanilla" + config.softmax_type != "vanilla" and qkv_format == "thd" + and get_cudnn_version() < (9, 18, 0) ): - pytest.skip( - "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" - " non-vanilla softmax types!" - ) + pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") + + if dtype == "fp8" and scaling_mode is None: + pytest.skip("dtype=fp8 requires scaling_mode != None!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("dtype!=fp8 requires scaling_mode = None!") + if dtype != "fp8" and not f16_O: + pytest.skip("dtype!=fp8 requires f16_O=True!") + if scaling_mode == "delayed" and f16_O: + pytest.skip("scaling_mode=delayed requires f16_O=False!") + if scaling_mode == "mxfp8" and not f16_O: + pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -354,6 +362,12 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + if fp8 and scaling_mode == "mxfp8": + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True) + fp8_meta["local_recipes"] = [ + MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True), + ] + # For 111s, dbias calculation is not supported as of cuDNN 9.18, 
hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -363,6 +377,7 @@ def test_cp_with_fused_attention( fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -382,6 +397,7 @@ def test_cp_with_fused_attention( scaling_mode=scaling_mode, f16_O=f16_O, is_training=is_training, + deterministic=_deterministic, log_level=pytest_logging_level, ), check=True, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 317240fb78..1747d75676 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -177,6 +177,7 @@ def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): rmse = torch.sqrt((a - b).square().mean()).item() logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + # rmse_tol = rmse_tol * 1.1 assert rmse < rmse_tol * rmse_range, ( name_a + " vs " diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index abdce7fdac..72c5273a78 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -49,6 +49,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_SD_SD_SD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -89,6 +91,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return 
NVTE_QKV_Format::NVTE_THD_2SBHD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -108,6 +112,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -127,11 +133,92 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } } +// map one NVTE_QKV_Format to another +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t) { + size_t _b = 0, _h = 0, _s = 0, _d = 0, _t = 0; + switch (src_format) { + case NVTE_QKV_Format::NVTE_BSHD: + _b = src_shape[0]; + _s = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_SBHD: + _s = src_shape[0]; + _b = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_BHSD: + _b = src_shape[0]; + _h = src_shape[1]; + _s = src_shape[2]; + _d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_THD: + _t = src_shape[0]; + _h = src_shape[1]; + _d = src_shape[2]; + break; + default: + NVTE_ERROR("src_format not supported!"); + break; + } + switch (dst_format) { + case NVTE_QKV_Format::NVTE_BSHD: + dst_shape[0] = _b; + dst_shape[1] = _s; + dst_shape[2] = _h; + dst_shape[3] = _d; + break; + case NVTE_QKV_Format::NVTE_SBHD: + dst_shape[0] = _s; + dst_shape[1] = _b; + dst_shape[2] = _h; + dst_shape[3] = _d; + break; + case NVTE_QKV_Format::NVTE_BHSD: + dst_shape[0] = _b; + 
dst_shape[1] = _h; + dst_shape[2] = _s; + dst_shape[3] = _d; + break; + case NVTE_QKV_Format::NVTE_THD: + dst_shape[0] = _t; + dst_shape[1] = _h; + dst_shape[2] = _d; + break; + default: + NVTE_ERROR("dst_format not supported!"); + break; + } + + if (b != nullptr) { + *b = _b; + } + if (h != nullptr) { + *h = _h; + } + if (s != nullptr) { + *s = _s; + } + if (d != nullptr) { + *d = _d; + } + if (t != nullptr) { + *t = _t; + } +} + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, @@ -183,8 +270,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: mxfp8, d_qk=128, d_v=192 + (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -324,12 +412,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && 
sm_arch_ >= 90) || kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -340,7 +431,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + ((window_size_left == -1 || window_size_left >= 0) && + (window_size_right == -1 || window_size_right >= 0) && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && @@ -449,19 +541,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, 
NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -483,8 +573,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; @@ -497,6 +585,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] + : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_K->data.shape[ndim_kv - 3] + : input_K->data.shape[ndim_kv - 2]; int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -560,9 +652,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, + attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, + input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); @@ -580,11 +673,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -608,8 +702,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; @@ -622,6 +714,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] + : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? 
input_K->data.shape[ndim_kv - 3] + : input_K->data.shape[ndim_kv - 2]; auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -672,11 +768,17 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, - output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + const Tensor *input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, + input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 80e64370f9..9796e39ddc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,16 +1652,19 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, + void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, 
NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -1675,13 +1678,18 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); try { FADescriptor_v1 descriptor{b, @@ -1689,8 +1697,8 @@ void fused_attn_fp8_fwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -1704,13 +1712,13 @@ void fused_attn_fp8_fwd_impl_v1( scaling_factor, is_training, 
dropout_probability, - layout, + qkv_layout, bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + window_size_left, + window_size_right, + bottom_right_diagonal, true, qkv_tensor_type, o_tensor_type, @@ -1765,28 +1773,31 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + // Q, K, V, attn_scale std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -1794,21 +1805,58 @@ void fused_attn_fp8_fwd_impl_v1( .set_is_pass_by_value(true) 
.set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - - if (is_delayed_scaling) { - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Scale_o + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_o"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } } - if (is_current_scaling) { - scale_o = mha_graph->tensor(1.0f); + if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, + v_scale_strides.data(), 
kv_format, false); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_attributes sdpa_options; @@ -1818,6 +1866,18 @@ void fused_attn_fp8_fwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } + // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1855,23 +1915,36 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + std::shared_ptr O, Stats, amax_s, amax_o; + if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; + } else { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_s = outputs[2]; + amax_o = outputs[3]; + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); + O->set_output(true) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - 
.set_data_type(fe::DataType_t::FLOAT); - Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h, s_q, 1}) @@ -1890,8 +1963,11 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o); + key_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, + nullptr, attn_scale, O, nullptr, amax_o) + : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = @@ -1904,7 +1980,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); @@ -1937,17 +2012,19 @@ void fused_attn_fp8_fwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if (is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } + if (!is_mxfp8) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[amax_s] = devPtrAmaxS; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -1971,7 +2048,6 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } - 
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -1980,20 +2056,25 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, - void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, - void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, - void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, - void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, + void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, + void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, + void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, + void* devPtrAmaxdK, void* devPtrAmaxdV, 
void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, + void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -2003,18 +2084,23 @@ void fused_attn_fp8_bwd_impl_v1( bool is_dropout = (dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; - const auto cudnn_runtime_version = cudnnGetVersion(); auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = 
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2024,8 +2110,8 @@ void fused_attn_fp8_bwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -2039,13 +2125,13 @@ void fused_attn_fp8_bwd_impl_v1( scaling_factor, true, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + window_size_left, + window_size_right, + bottom_right_diagonal, deterministic, qkv_tensor_type, o_tensor_type, @@ -2056,18 +2142,25 @@ void fused_attn_fp8_bwd_impl_v1( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::shared_ptr, // Q + std::shared_ptr, // Q_t + std::shared_ptr, // K + std::shared_ptr, // K_t + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2108,54 +2201,59 @@ void 
fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, + attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, + descale_v; std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; + std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; std::shared_ptr bias, dBias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + // Q, K, V, O, dO, stats, attn_scale std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + 
.set_data_type(qkv_tensor_type)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); + O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(do_tensor_type)); + Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") .set_dim({b, h, s_q, 1}) .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -2163,33 +2261,138 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { - descale_o = mha_graph->tensor(1.0f); - } else { - descale_o = 
mha_graph->tensor_like(descale_q, "Descale_O"); - } - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - - if (is_delayed_scaling) { - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Descale_dP, Scale_dP, Descale_o, Descale_dO, Scale_dQ, Scale_dK, Scale_dV + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + if (is_current_scaling && is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } } - if (is_current_scaling) { - scale_dQ = mha_graph->tensor(1.0f); - scale_dK = mha_graph->tensor(1.0f); - scale_dV = mha_graph->tensor(1.0f); + if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format 
kv_format = nvte_get_kv_format(qkv_layout); + // Q_t, K_t, dO_t, dO_f16 + std::vector q_t_stride(4); + std::vector k_t_stride(4); + std::vector dO_t_stride(4); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); + Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); + K_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type)); + // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + std::vector q_scale_strides(4); + std::vector q_t_scale_strides(4); + std::vector k_scale_strides(4); + std::vector k_t_scale_strides(4); + std::vector v_scale_strides(4); + std::vector dO_scale_strides(4); + std::vector dO_t_scale_strides(4); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, + q_t_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, + 
k_t_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, + v_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, + dO_scale_strides.data(), d_out_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, + dO_t_scale_strides.data(), d_out_format, false); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_q_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k_t") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) + .set_stride(k_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO") + 
.set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) + .set_stride(dO_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) + .set_stride(dO_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; @@ -2198,6 +2401,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } + // sdpa_backward_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2250,14 +2465,52 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; + if 
(is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, sdpa_backward_options); + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + amax_dP = outputs[6]; + } + if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8_backward( + Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, + descale_k_t, descale_v, descale_dO, descale_dO_t, sdpa_backward_options); + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + } + std::vector dq_stride(4); + std::vector dk_stride(4); + std::vector dv_stride(4); + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, dq_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dk_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + dQ->set_output(true) + .set_dim({b, h, s_q, d_qk}) + .set_stride(dq_stride) + .set_data_type(dqkv_tensor_type); + dK->set_output(true) + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(dk_stride) + .set_data_type(dqkv_tensor_type); + dV->set_output(true) + .set_dim({b, hg, s_kv, d_v}) + .set_stride(dv_stride) + .set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2270,21 +2523,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - - dO->set_data_type(do_tensor_type); - 
dQ->set_data_type(dqkv_tensor_type); - dK->set_data_type(dqkv_tensor_type); - dV->set_data_type(dqkv_tensor_type); + if (!is_mxfp8) { + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } - std::tuple, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO std::shared_ptr, // attn_scale std::shared_ptr, // descale_q @@ -2307,9 +2557,12 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dV std::shared_ptr> // amax_dP key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); + auto mxfp8_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? 
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2322,17 +2575,18 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, + bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, + descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = + get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2349,39 +2603,49 @@ void fused_attn_fp8_bwd_impl_v1( // build variant pack std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrK}, - {v, devPtrV}, - {o, devPtrO}, - {stats, devPtrM}, + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {Stats, devPtrM}, {dO, devPtrdO}, {attn_scale, &scaling_factor}, {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, {descale_dO, devPtrDescaledO}, - {descale_s, devPtrDescaleS}, - {descale_dP, devPtrDescaledP}, - {scale_s, devPtrScaleS}, - {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, {amax_dQ, devPtrAmaxdQ}, {amax_dK, devPtrAmaxdK}, {amax_dV, devPtrAmaxdV}, - 
{amax_dP, devPtrAmaxdP}, }; - + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + } + if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { + variant_pack[descale_o] = devPtrDescaleO; + } if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (!is_O_in_F16) { - variant_pack[descale_o] = devPtrDescaleO; + if (is_mxfp8) { + variant_pack[Q_t] = devPtrQ_t; + variant_pack[K_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + // variant_pack[descale_dO] = devPtrDescaledO; + variant_pack[descale_dO_t] = devPtrDescaledO_t; } - /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { @@ -2423,26 +2687,54 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + 
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, + Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = input_Q->data.dptr; - void* devPtrK = input_K->data.dptr; - void* devPtrV = input_V->data.dptr; - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - + void* devPtrQ = nullptr; + void* devPtrK = nullptr; + void* devPtrV = nullptr; + void* devPtrDescaleQ = nullptr; + void* devPtrDescaleK = nullptr; + void* devPtrDescaleV = nullptr; + void* devPtrO = nullptr; + void* devPtrAmaxO = nullptr; + void* devPtrScaleO = nullptr; + void* devPtrAmaxS = nullptr; + void* devPtrScaleS = nullptr; + void* devPtrDescaleS = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + // devPtrV = input_V->data.dptr; + // devPtrDescaleV = input_V->scale_inv.dptr; + devPtrV = input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + } else { + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = 
output_O->amax.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { @@ -2470,10 +2762,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = @@ -2488,17 +2776,20 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrM, 
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, + devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, + stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, @@ -2522,16 +2813,19 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, - const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const 
Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { + const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, + const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, + Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, + const Tensor* output_dV, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2539,6 +2833,10 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_Q->scale_inv.dptr; void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrQ_t = input_Q->columnwise_data.dptr; + void* devPtrK_t = input_K->columnwise_data.dptr; + void* devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; + void* devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2548,6 +2846,9 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; + void* devPtrdO_t = input_dO->columnwise_data.dptr; + void* devPtrdO_f16 = input_dO_f16->data.dptr; + void* devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; void* devPtrM = input_M->data.dptr; void* devPtrZInv = input_ZInv->data.dptr; @@ -2582,21 +2883,25 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == 
NVTE_QKV_Format::NVTE_SBHD)) { + NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, - devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, + devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, + devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, + devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), + 
get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, + &workspace_size, stream, handle); + } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 225e700eff..98d5876ec8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,26 +15,31 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor 
*input_K, const Tensor *input_V, Tensor *input_output_S, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, + cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, + const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, + Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, + const Tensor *output_dV, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu 
b/transformer_engine/common/fused_attn/utils.cu index a897b09330..e67ae5e206 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,6 +293,32 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 08a56cda6b..3e4ca696e2 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -24,14 +24,330 @@ using namespace transformer_engine; enum NVTE_QKV_Matrix { NVTE_Q_Matrix = 0, // queries - NVTE_K_Matrix = 1, // keys - NVTE_K_Matrix_Transpose = 2, // keys transposed - NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_Q_Matrix_Transpose = 1, // queries 
transposed + NVTE_K_Matrix = 2, // keys + NVTE_K_Matrix_Transpose = 3, // keys transposed + NVTE_V_Matrix = 4, // values + NVTE_V_Matrix_Transpose = 5, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output }; +// Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) +struct MXFP8PaddedSizes { + int64_t s_q_padded; + int64_t s_kv_padded; + int64_t s_q_scale; + int64_t s_kv_scale; + int64_t s_q_scale_padded; + int64_t s_kv_scale_padded; + int64_t d_qk_padded; + int64_t d_v_padded; + int64_t d_qk_scale; + int64_t d_v_scale; + int64_t d_qk_scale_padded; + int64_t d_v_scale_padded; +}; + +inline bool is_aligned_modulo(void *ptr, int64_t modulo) { + // Cast the pointer to a large enough integer type (uintptr_t) + uintptr_t address = reinterpret_cast(ptr); + // Check if the address is perfectly divisible by 16 + return (address % modulo) == 0; +} + +// Pad s and d for MXFP8 layout +inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { + constexpr int64_t block_size = 32; + MXFP8PaddedSizes p; + p.s_q_padded = ((s_q + 127) / 128) * 128; + p.s_kv_padded = ((s_kv + 127) / 128) * 128; + p.s_q_scale = (s_q + block_size - 1) / block_size; + p.s_kv_scale = (s_kv + block_size - 1) / block_size; + p.s_q_scale_padded = ((p.s_q_scale + 3) / 4) * 4; + p.s_kv_scale_padded = ((p.s_kv_scale + 3) / 4) * 4; + p.d_qk_padded = ((d_qk + 127) / 128) * 128; + p.d_v_padded = ((d_v + 127) / 128) * 128; + p.d_qk_scale = (d_qk + block_size - 1) / block_size; + p.d_v_scale = (d_v + block_size - 1) / block_size; + p.d_qk_scale_padded = ((p.d_qk_scale + 3) / 4) * 4; + p.d_v_scale_padded = ((p.d_v_scale + 3) / 4) * 4; + return p; +} + +// Get matrix strides for a 4D tensor [batch, head, seqlen, hidden] given a QKV format. +// strideA must point to at least 4 int64_t elements. 
+inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strides, NVTE_QKV_Format format, + bool transpose) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 2 : 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strides[batch_dim_idx] = s * h * d; + strides[head_dim_idx] = d; + strides[seqlen_dim_idx] = h * d; + strides[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strides[batch_dim_idx] = h * d; + strides[head_dim_idx] = d; + strides[seqlen_dim_idx] = b * h * d; + strides[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strides[batch_dim_idx] = h * s * d; + strides[head_dim_idx] = s * d; + strides[seqlen_dim_idx] = d; + strides[hidden_dim_idx] = 1; + break; + default: + NVTE_CHECK(false, "Invalid format."); + break; + } +} + +// get matrix strides based on matrix type +inline void generateMatrixStrides_v1(int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, + int64_t d_qk, int64_t d_v, int64_t *strides, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + bool transpose = (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 
2 : 3; + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; + } + NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, + "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 
1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + 
strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + 
} else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * s_q * d_qk; + strides[head_dim_idx] = s_q * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_qk; + strides[head_dim_idx] = s_kv * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_v; + strides[head_dim_idx] = s_kv * d_v; + strides[seqlen_dim_idx] = d_v; + strides[hidden_dim_idx] = 1; + } + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } + + if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { + 
strides[seqlen_kv_dim_idx] = 1; + strides[seqlen_q_dim_idx] = s_kv; + strides[head_dim_idx] = s_q * s_kv; + strides[batch_dim_idx] = h * s_q * s_kv; + } +} + void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8169bf22e2..90393ce8c8 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,6 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ + NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -70,6 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, + /*! SD_SD_SD QKV layouts, e.g. BHSD_BHSD_BHSD */ + NVTE_SD_SD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -90,6 +93,8 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, + /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ + NVTE_BHSD = 7, }; /*! \enum NVTE_Bias_Type @@ -188,6 +193,24 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Convert one NVTE_QKV_Format to another. + * + * \param[in] src_format The source format. + * \param[in] src_shape The source shape. + * \param[in] dst_format The destination format. + * \param[in,out] dst_shape The destination shape. + * \param[in,out] b The batch size. 
+ * \param[in,out] h The number of heads. + * \param[in,out] s The sequence length. + * \param[in,out] d The head dimension. + * \param[in,out] t The time dimension. + * + * \return The destination shape. + */ +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] is_training Whether the model is in training mode. @@ -274,6 +297,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -283,19 +307,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -347,6 +369,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. 
* \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. + * \param[in] d_out_format Output gradient's format. + * \param[in] dqkv_layout QKV gradient tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -366,11 +391,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6adba23a8f..96e6803ec5 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -48,7 +48,8 @@ .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -74,7 +75,8 @@ .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ + .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index aa6c063951..906f3ade45 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,12 +29,14 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) 
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorStorage from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, @@ -173,15 +175,34 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - q_fp8, k_fp8, v_fp8 = combine_and_quantize( + assert qkv_layout == "sbhd_sbhd_sbhd", ( + "sbhd_sbhd_sbhd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." + ) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( - qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + qkv_layout, + q_fp8, + k_fp8, + v_fp8, + src_nominal_dtype=query_layer.dtype, + des_nominal_dtype=query_layer.dtype, ) + if isinstance(quantizer, MXFP8Quantizer): + assert qkv_layout == "bhsd_bhsd_bhsd", ( + "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." 
+ ) + # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: - t_fp8 = quantizer(tensor1) - tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + if quantizer is not None: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) else: tensors = (tensor1, tensor2, tensor3) ctx.quantizer = quantizer @@ -193,16 +214,26 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou def backward(ctx, grad1, grad2, grad3): # pylint: disable=missing-function-docstring if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: - dt_fp8 = ctx.quantizer(grad1) - tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + if ctx.quantizer is not None: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + else: + tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer ) tensors = combine_and_dequantize( ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + assert ctx.qkv_layout == "bhsd_bhsd_bhsd", ( + "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." 
+ ) + # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -375,6 +406,7 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) + apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 if "padding" in attn_mask_type and attention_mask is None: attention_mask = dpa_utils.get_padding_mask( @@ -401,9 +433,6 @@ def forward( ) ) - batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] - apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 - # [b, np, sq, sk] output_size = ( query_layer.size(1), @@ -423,11 +452,6 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) - # preallocting result tensor: [b * np, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], @@ -442,14 +466,15 @@ def forward( scale /= self.layer_number if fp8: + # get fp8 recipe for DPA + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if 
fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=S_quantizer.dtype, device="cuda" @@ -457,19 +482,44 @@ def forward( dP_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=dP_quantizer.dtype, device="cuda" ) + # disable swizzle for MXFP8Quantizer + for q in [ + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ]: + if isinstance(q, MXFP8Quantizer): + q.optimize_for_gemm = False + q.internal = False - if "2" in qkv_layout or "3" in qkv_layout: - qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) - qkv_layout = "_".join([qkv_format] * 3) + # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + QKV_quantizer, + "QKV_quantizer", + "sbhd_sbhd_sbhd", ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + dQKV_quantizer, + "dQKV_quantizer", + "sbhd_sbhd_sbhd", ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + # Raw attention scores. 
[b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -599,14 +649,14 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [b, sq, hp] - context_layer = context_layer.view(batch_size, seqlen, -1) + context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": # [b, np, sq, hn] --> [b, sq, np, hn] @@ -1194,21 +1244,25 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + # save qkv_layout and get output format + original_qkv_layout = qkv_layout + _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) + # input types are inferred from the real data while output types are controlled by fp8_output # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." 
+ is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) - # whether bwd kernel in FP8: + # whether fwd kernel will be run in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel will be run in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # get nominal data type for out @@ -1221,12 +1275,15 @@ def forward( fused_attention_backend = FusedAttnBackend["FP8"] # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E4M3 + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; + # dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) # print quantizers print_quantizers( @@ -1244,6 +1301,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1266,6 +1324,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1275,21 +1334,34 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - - # out_fp8: Float8Tensor; dtype = torch.float16 or 
torch.bfloat16 + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - - if isinstance(out_, Float8Tensor): - if not is_output_fp8 or not is_bwd_fp8: + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if isinstance(out_, QuantizedTensorStorage): + if not is_output_fp8 or bwd_requires_o_f16: out = out_.dequantize().view(out_.shape) else: - if is_output_fp8 or ( - is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): + if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) # print quantizers @@ -1311,10 +1383,14 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) - else: + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: @@ -1344,6 +1420,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1431,7 +1508,11 @@ def forward( ctx.qkv_layout = qkv_layout else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type 
ctx.softmax_type = softmax_type @@ -1451,13 +1532,19 @@ def forward( def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - # d_out is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): - d_out = ctx.dO_quantizer(d_out) - if not ctx.use_FAv2_bwd: - d_out._data = d_out._data.contiguous() - elif not ctx.use_FAv2_bwd: + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + d_out_fp8 = None + d_out_format = ctx.o_format + if ctx.fp8: + if ctx.fp8_recipe.mxfp8(): + d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) + if isinstance(d_out, QuantizedTensorStorage): + d_out_fp8 = d_out + else: + d_out_fp8 = ctx.dO_quantizer(d_out) + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( q_fp8, @@ -1519,14 +1606,6 @@ def backward(ctx, d_out, *_args): dqkv_nominal_dtype = ctx.nominal_dtype if ctx.fp8: - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - if ctx.is_output_fp8: - d_out_fp8 = d_out - else: - d_out_fp8 = ctx.dO_quantizer(d_out) - # print quantizers print_quantizers( "FusedAttnFunc.backward >> before: ", @@ -1539,27 +1618,26 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # get tex.DType for dq, dk, dv data - dqkv_te_dtype = d_out_fp8._fp8_dtype + # # get tex.DType for dq, dk, dv data + # dqkv_te_dtype = d_out_fp8._fp8_dtype # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # out_: - # 
DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # DelayedScaling: + # out_, dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # - # dq_, dk_, dv_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_ = ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8 - ) + # Float8CurrentScaling: + # out_, dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_ = out + if ctx.fp8_recipe.mxfp8(): + out_ = out + aux_ctx_tensors.append(d_out) dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1571,7 +1649,7 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, # could we remove this? 
aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1583,6 +1661,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1594,8 +1675,8 @@ def backward(ctx, d_out, *_args): # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_float8tensor = isinstance(dq_, Float8Tensor) - if is_float8tensor and not ctx.is_input_fp8: + is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) + if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( ctx.qkv_layout, @@ -1604,9 +1685,9 @@ def backward(ctx, d_out, *_args): dv_, src_nominal_dtype=dq_.dtype, ) - if not is_float8tensor and ctx.is_input_fp8: + if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv = combine_and_quantize( + dq, dk, dv, _ = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) @@ -1624,7 +1705,7 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - dqkv_te_dtype = TE_DType[d_out.dtype] + # dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1637,7 +1718,7 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1649,6 +1730,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1921,7 +2005,7 @@ def forward( " with FP8!" 
) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index bd6b626b64..7886c625b4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -23,6 +23,8 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.common.recipe import MXFP8BlockScaling, Format from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.constants import ( @@ -59,6 +61,18 @@ _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def get_bsh_dims(tensor_format): + """Get batch dimension and sequence dimension from tensor format""" + if tensor_format in ["bshd", "sbhd", "bhsd"]: + batch_dim = tensor_format.index("b") + seq_dim = tensor_format.index("s") + head_dim = tensor_format.index("h") + else: # tensor_format == "thd" + batch_dim = seq_dim = tensor_format.index("t") + head_dim = tensor_format.index("h") + return batch_dim, seq_dim, head_dim + + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -419,6 +433,7 @@ def flash_attn_a2a_communicate( ), "cu_seqlens_padded is required for THD format!" 
a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + batch_dim, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -430,13 +445,14 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # reorder the sequence chunks x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # or [b, np//cp, cp*2, s//2, hn] -> [b, np//cp, cp*s, hn] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) @@ -452,12 +468,19 @@ def flash_attn_a2a_communicate( x = a2a_inputs[i] # [b, s, np, hn] -> [b, s, cp, np//cp, hn] # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # or [b, np, s, hn] -> [b, cp, np//cp, s, hn] # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + x = x.view( + *x.shape[:head_dim], + cp_size, + x.shape[head_dim] // cp_size, + *x.shape[head_dim + 1 :], + ) # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # or [b, cp, np//cp, s, hn] -> [cp, b, np//cp, s, hn] # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() + a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -467,9 +490,10 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # or [b, 
np//cp, cp*s, hn] -> [b, np//cp, cp*2, s//2, hn] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -486,10 +510,12 @@ def flash_attn_a2a_communicate( x = a2a_outputs[i - 2] # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + x = x.movedim(0, head_dim + 1).movedim(0, seq_dim + 1).contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] # or [t, cp, np//cp, hn] -> [t, np, hn] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) @@ -775,13 +801,16 @@ def cp_p2p_fwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step, O_quantizer_per_step, rank, @@ -867,11 +896,17 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded_ = cu_seqlens_kv_padded fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -888,7 +923,8 @@ def 
cp_p2p_fwd_fused_attn( fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs, @@ -907,7 +943,7 @@ def cp_p2p_fwd_fused_attn( if return_max_logit: return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit - return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None # , new_qkv_layout def cp_p2p_fwd_flash_attn( @@ -1065,15 +1101,21 @@ def cp_p2p_bwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, + d_out_format, + dqkv_layout, attn_mask_type, attn_bias_type, deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, + # O_quantizer_per_step, + QKV_quantizer_per_step, + dO_quantizer_per_step, q_part, k_part, v_part, @@ -1123,16 +1165,28 @@ def cp_p2p_bwd_fused_attn( fp8_meta_kwargs = {} if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip( - [q_fp8, kv_fp8, kv_fp8], - [q_part, k_part, v_part], + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step ) - ] - if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): - out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + if not fp8_recipe.mxfp8(): + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = 
Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + else: + # out_part, o_format = dpa_utils.permute_to_grouped_tensor(o_format, out_part) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + # out_part = O_quantizer_per_step(out_part) + aux_tensors.append(dout_part) + dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1148,7 +1202,7 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1156,6 +1210,9 @@ def cp_p2p_bwd_fused_attn( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, + d_out_format=d_out_format, + dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, @@ -1367,7 +1424,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) q_f16 = None q_fp8, k_fp8, v_fp8 = (None, None, None) @@ -1376,12 +1433,13 @@ def forward( if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = (q._data, k._data, v._data) + if not fp8_recipe.mxfp8(): + q, k, v = (q._data, k._data, v._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) - if fp8 and is_input_fp8: + if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in 
zip([q_fp8, k_fp8, v_fp8], [q, k, v]) @@ -1397,13 +1455,15 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - else: + elif not fp8_recipe.mxfp8(): # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers @@ -1424,10 +1484,11 @@ def forward( # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + S_quantizer_per_step[i] = S_quantizer.copy() if S_quantizer is not None else None O_quantizer_per_step[i] = O_quantizer.copy() - O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not fp8_recipe.mxfp8(): + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype @@ -1553,6 +1614,7 @@ def forward( k_shape = k.shape k_numel = k.numel() v_shape = v.shape + o_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] @@ -1560,6 +1622,7 @@ def forward( # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None + o_format = qkv_format for i 
in range(cp_size + 1): if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): @@ -1613,13 +1676,16 @@ def forward( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step[i], O_quantizer_per_step[i], rank, @@ -1671,6 +1737,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1698,6 +1765,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1725,6 +1793,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1753,6 +1822,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1780,7 +1850,7 @@ def forward( out_per_step[i - 1] = out_per_step[i - 1].dequantize( dtype=torch.float32 ) - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) if i == 1: @@ -1788,7 +1858,7 @@ def forward( if qkv_format == "thd": if enable_mla: out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape + o_shape ) else: # MHA or GQA @@ -1834,7 +1904,7 @@ def forward( # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: if i == 0: out = flash_attn_fwd_out_correction_init( out_per_step[0], @@ -1843,7 +1913,7 @@ def forward( seq_dim, ) if enable_mla: - out = out.view(v_shape) + out = out.view(o_shape) else: out = 
out.view(q.shape) else: @@ -1854,7 +1924,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1865,7 +1935,7 @@ def forward( softmax_lse_in_packed_format, ) else: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: flash_attn_fwd_second_half_out_correction( out, out_per_step[i], @@ -1873,7 +1943,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1884,12 +1954,13 @@ def forward( softmax_lse_in_packed_format, ) - if qkv_format == "bshd": + if o_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] - elif qkv_format == "sbhd": + elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] + out_part = out.to(fwd_nominal_dtype) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -1897,10 +1968,10 @@ def forward( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) if use_fused_attention: - if qkv_format == "bshd": + if o_format == "bshd": # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) if return_max_logit: @@ -1911,7 +1982,7 @@ def forward( out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps - if fp8 and use_fused_attention: + if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) O_quantizer.amax.copy_(amax_cp_fwd[1]) @@ -1934,7 +2005,11 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and 
_dpa_fp8_cs_o_in_f16)) + or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() + ) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -1945,7 +2020,7 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8], [q, kv]) @@ -1953,17 +2028,28 @@ def forward( # q, kv, out fp8_tensors = (None, None, None) f16_tensors = (None, None, None) + out_f16 = out_part if ctx.fp8: # fwd: fp8, bwd: fp8, save all fp8 fp8_tensors = (q_fp8, kv_fp8, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: f16_tensors = (None, None, out_f16) - elif fp8 and is_input_fp8: + elif fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) + elif fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): # fwd: fp8, bwd: f16, save all f16 # dequantize fp8 inputs q_f16 = q_fp8.dequantize() kv_f16 = kv_fp8.dequantize() f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and is_input_fp8 and fp8_recipe.mxfp8(): + # fwd: fp8, bwd: f16, save all f16 + # there is already an F16 version of the inputs + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) + kv = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv, out_f16) + elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) elif fp8: # fwd: fp8, bwd: f16, save all f16 # inputs are already in f16 @@ -2018,6 +2104,7 @@ def forward( ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer @@ -2028,11 +2115,12 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - 
ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") @@ -2050,7 +2138,12 @@ def backward(ctx, dout, *_args): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + if ( + ctx.fp8 + and ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2091,6 +2184,7 @@ def backward(ctx, dout, *_args): causal = "causal" in ctx.attn_mask_type seq_dim = None qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + o_format = ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2150,28 +2244,33 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - bwd_output_te_dtype = None + # bwd_output_te_dtype = None dkv_buffer = None + d_out_format = o_format if ctx.fp8: assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" 
fused_attn_backend = FusedAttnBackend["FP8"] - q, kv, out = ( - q_fp8._data, - kv_fp8._data, - ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8._data - ), - ) + if not ctx.fp8_recipe.mxfp8(): + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - if ctx.is_output_fp8: + # if ctx.fp8_recipe.mxfp8(): + # dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout - else: + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) - dout = dout_fp8._data + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data # print quantizers print_quantizers( @@ -2186,7 +2285,7 @@ def backward(ctx, dout, *_args): ) # dout_fp8._fp8_dtype - bwd_output_te_dtype = ctx.dO_quantizer.dtype + # bwd_output_te_dtype = ctx.dO_quantizer.dtype # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): @@ -2195,7 +2294,7 @@ def backward(ctx, dout, *_args): dtype=buffer_dtype, device=q.device, ) - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_buffer = torch.empty( q.shape, dtype=torch.float32, @@ -2209,7 +2308,7 @@ def backward(ctx, dout, *_args): ) dkv_recv_buffer = torch.empty_like(dkv_send_buffer) p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dkv_buffer = torch.zeros( kv.shape, dtype=torch.float32, @@ -2222,10 +2321,13 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in 
range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dP_quantizer_per_step[i] = ( + ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + ) dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() - dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not ctx.fp8_recipe.mxfp8(): + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) @@ -2236,19 +2338,19 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] + # bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) - out, dout = flash_attn_a2a_communicate( - [out, dout], + dout = flash_attn_a2a_communicate( + [dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, @@ -2258,8 +2360,8 @@ def backward(ctx, dout, *_args): ) if ctx.enable_mla: - out = out.view(*ctx.v_shape) - dout = dout.view(*ctx.v_shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) else: # MHA or GQA out = out.view(*q.shape) @@ -2360,10 +2462,11 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or ctx.fp8_recipe.mxfp8() else out_fp8 ), - dout_fp8, + dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, softmax_lse, softmax_lse_, 
rng_states, @@ -2381,15 +2484,21 @@ def backward(ctx, dout, *_args): ctx.softmax_scale, ctx.dropout_p, qkv_layout, + ctx.qkv_format, + ctx.qkv_format, + qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], + # ctx.O_quantizer, + ctx.QKV_quantizer, + ctx.dO_quantizer, ] else: flash_attn_inputs = [ @@ -2463,7 +2572,7 @@ def backward(ctx, dout, *_args): if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8_recipe.delayed(): dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] # copy dq_ into the right buffer position @@ -2547,7 +2656,7 @@ def backward(ctx, dout, *_args): # dkv correction if ctx.fp8 and ctx.fp8_recipe.delayed(): dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] - elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + elif ctx.fp8 and (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()): dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] @@ -2637,9 +2746,10 @@ def backward(ctx, dout, *_args): # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + if not ctx.fp8_recipe.mxfp8(): + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) dq = dq_buffer if ctx.fp8_recipe.delayed(): @@ -2662,7 +2772,7 @@ def backward(ctx, dout, *_args): ) dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dk = dkv[: 
ctx.k_numel].view(ctx.k_shape) dv = dkv[ctx.k_numel :].view(ctx.v_shape) @@ -2678,7 +2788,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers @@ -2696,7 +2806,8 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if ctx.fp8 and ctx.is_input_fp8: dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv - dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) + if not ctx.fp8_recipe.mxfp8(): + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2707,7 +2818,7 @@ def backward(ctx, dout, *_args): ctx.cp_stream, False, ) - if ctx.fp8 and ctx.is_input_fp8: + if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) @@ -2813,6 +2924,10 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -2822,7 +2937,8 @@ def forward( cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) - qkv_dtype = q.dtype + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -2866,9 +2982,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
- qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - seq_dim = qkv_format.index("s") assert ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 @@ -2883,10 +2996,55 @@ def forward( else: cu_seqlens_q_padded = None + # FP8 setup + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + ( + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + fwd_nominal_dtype = q.dtype + fp8_meta_kwargs = {} + q_fp8, k_fp8, v_fp8 = (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) + fused_attn_backend = None + if fp8: + assert use_fused_attention, "FP8 is only supported with Fused Attention!" 
+ fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + elif not fp8_recipe.mxfp8(): + q_f16, k_f16, v_f16 = q, k, v + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + else: + q_f16, k_f16, v_f16 = q, k, v + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer + elif use_fused_attention: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + q_shape = q.shape # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + k_shape = k.shape + v_shape = v.shape # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) @@ -2913,7 +3071,9 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) + enable_mla = k.shape[-1] != v.shape[-1] + out_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] + out = torch.empty(out_shape, dtype=fwd_nominal_dtype, device=q.device) max_logit_per_step = [None, None] max_logit = None @@ -2946,9 +3106,26 @@ def forward( # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: + q_part, k_part, v_part = q_, k_, v_ + new_qkv_layout = qkv_layout + if fp8: + if not fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like( + q_fp8, data=q_, dtype=fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_, dtype=fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_, dtype=fwd_nominal_dtype + ) + else: + q_part, k_part, v_part, new_qkv_layout = 
combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ( out_per_step[i], - [softmax_lse_per_step[i], rng_states[i]], + aux_ctx_tensors, *max_logit_, ) = fused_attn_fwd( is_training, @@ -2956,14 +3133,15 @@ def forward( max_seqlen_kv_, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - qkv_dtype, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + q_part, + k_part, + v_part, + fwd_nominal_dtype, + fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -2972,9 +3150,17 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] + if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): + out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: fa_forward_args_thd = get_fa_args( True, @@ -3034,10 +3220,42 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) - ctx.save_for_backward( - q, - k, - v, + out_fp8 = None + out_ret = out + if fp8 and (is_output_fp8 or (is_bwd_fp8 and fp8_recipe.delayed())): + out_fp8 = O_quantizer(out) + out_ret = out_fp8 + ctx.fp8 = fp8 and is_bwd_fp8 + ctx.fp8_recipe = fp8_recipe + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + ctx.qkv_reshaped = True + if ctx.fp8: + q_fp8_save, k_fp8_save, v_fp8_save = None, None, None + if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): + q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) + k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) + v_fp8_save = Float8Tensor.make_like(v_fp8, 
data=v, dtype=fwd_nominal_dtype) + if fp8_recipe.delayed(): + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) + if fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) + f16_tensors = (None, None, None, out) + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out) + elif fp8: + if is_input_fp8: + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + f16_tensors = (q_f16, k_f16, v_f16, out) + ctx.qkv_reshaped = False + else: + f16_tensors = (q, k, v, out) + + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, @@ -3045,8 +3263,14 @@ def forward( *softmax_lse_per_step, *rng_states, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.q_shape = q_shape + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.out_shape = out_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group @@ -3060,10 +3284,28 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.O_quantizer = O_quantizer.copy() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = 
QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: - return out, max_logit - return out + return out_ret, max_logit + return out_ret @staticmethod def backward(ctx, dout, *_args): @@ -3072,22 +3314,64 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] - cu_seqlens_kv_per_step = saved_tensors[5:7] - out_per_step = saved_tensors[7:9] - softmax_lse_per_step = saved_tensors[9:11] - rng_states = saved_tensors[11:13] + cu_seqlens_kv_per_step = [None, None] + out_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_q_padded, + cu_seqlens_kv_per_step[0], + cu_seqlens_kv_per_step[1], + out_per_step[0], + out_per_step[1], + softmax_lse_per_step[0], + softmax_lse_per_step[1], + rng_states[0], + rng_states[1], + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - dout = dout.view(q.shape) - dq = torch.empty_like(q) - dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) - dv = torch.zeros_like(dk) + dout = dout.view(ctx.out_shape) + dout_fp8 = None + if ctx.fp8: + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout_fp8 = ctx.dO_quantizer(dout) + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data + + if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): + q, k, v = [x._data for x in [q_fp8, k_fp8, 
v_fp8]] + if not ctx.qkv_reshaped: + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + + dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dk = torch.zeros( + (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=k.device, + ) + dv = torch.zeros( + (ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=v.device, + ) dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3162,32 +3446,85 @@ def backward(ctx, dout, *_args): out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] + q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout + d_out_format = ctx.qkv_format + if ctx.fp8: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + if not ctx.fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like( + q_fp8, data=q_, dtype=ctx.fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_, dtype=ctx.fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_, dtype=ctx.fwd_nominal_dtype + ) + if not ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): + out_part = ctx.O_quantizer(out_part) + dout_part = Float8Tensor.make_like( + dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype + ) + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, 
v_part, ctx.QKV_quantizer + ) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( + d_out_format, dout_part + ) + aux_ctx_tensors.append(dout_part) + dout_part = ctx.dO_quantizer(dout_part) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - out_, - dout_, - ctx.qkv_dtype, - TE_DType[dout.dtype], + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.fwd_nominal_dtype, + # TE_DType[dout.dtype], aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=ctx.qkv_format, + d_out_format=d_out_format, + dqkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if ctx.fp8: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + ( + x.dequantize(dtype=ctx.fwd_nominal_dtype) + if isinstance(x, QuantizedTensorStorage) + else x + ) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] @@ -3265,6 +3602,10 @@ def backward(ctx, dout, *_args): dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) dk = dk.movedim(0, seq_dim).contiguous() dv = dv.movedim(0, seq_dim).contiguous() + + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3289,6 +3630,10 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, + None, ) @@ -3394,21 +3739,19 @@ def forward( ), "The 
number of attention heads needs to be divisible by CP size!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - if qkv_format in ["bshd", "sbhd"]: - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - batch_dim = seq_dim = qkv_format.index("t") + original_qkv_layout = qkv_layout + o_format = qkv_format + batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -3421,7 +3764,7 @@ def forward( max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) @@ -3430,10 +3773,16 @@ def forward( fused_attn_backend = FusedAttnBackend["FP8"] if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + elif not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + # else: + # q, k, v = [q_fp8, k_fp8, v_fp8] + # qkv_format, _, _ = 
dpa_utils.get_qkv_format(qkv_layout) + # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer @@ -3448,7 +3797,7 @@ def forward( q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, - seq_dim, + seq_dim_qkv, cp_size, cp_group, cp_stream, @@ -3463,15 +3812,20 @@ def forward( out_fp8 = None out_f16 = None - batch_size = q.shape[batch_dim] + batch_size = q.shape[batch_dim_qkv] q_part, k_part, v_part = q, k, v out_part = None if use_fused_attention: - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_part, k_part, v_part = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] + if fp8 and fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3486,6 +3840,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3498,7 +3853,7 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if isinstance(out_, Float8Tensor): + if isinstance(out_, QuantizedTensorStorage): out_fp8 = out_ out_ = out_._data if is_bwd_fp8 and not ( @@ -3514,6 +3869,7 @@ def forward( fp8 and is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() ): out_part = O_quantizer(out_) else: @@ -3547,12 +3903,12 @@ def forward( out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, - seq_dim, + seq_dim_o, cp_size, cp_group, cp_stream, before_attn=False, - qkv_format=qkv_format, + qkv_format=o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) if return_max_logit: @@ -3561,15 
+3917,15 @@ def forward( ) if use_fused_attention: - if qkv_format == "bshd": + if o_format == "bshd": # [b*s, h, d] -> [b, s, h, d] out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out_ = out_.view(-1, batch_size, *out_.shape[-2:]) if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_f16 = out_ if is_output_fp8: out_fp8 = O_quantizer(out_) @@ -3583,17 +3939,23 @@ def forward( out_ret = out_fp8 if is_output_fp8 else out_f16 ctx.fp8 = fp8 and is_bwd_fp8 + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) if ctx.fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): fp8_tensors = (q_part, k_part, v_part, None) f16_tensors = (None, None, None, out_part) else: fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8: + elif fp8 and not fp8_recipe.mxfp8(): q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) f16_tensors = (q_part, k_part, v_part, out_part) + elif fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout else: f16_tensors = (q_part, k_part, v_part, out_part) @@ -3617,7 +3979,7 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format + # ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -3639,11 +4001,13 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - 
ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -3671,27 +4035,28 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_format = ctx.qkv_format - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + # qkv_format = ctx.qkv_format + # qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + # qkv_layout = ctx.qkv_layout causal = "causal" in ctx.attn_mask_type + dqkv_format, _, _ = dpa_utils.get_qkv_format(ctx.dqkv_layout) - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - seq_dim = qkv_format.index("t") + batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(dqkv_format) + _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype - dqkv_te_dtype = None + # dqkv_te_dtype = None fused_attn_backend = None dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage): + if not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dout = ctx.dO_quantizer(dout) dout_fp8 = dout - dqkv_te_dtype = dout._fp8_dtype - dout = dout._data + if not ctx.fp8_recipe.mxfp8(): + # dqkv_te_dtype = dout._fp8_dtype + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -3704,11 +4069,11 @@ def backward(ctx, dout, *_args): dout = 
dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - dqkv_te_dtype = TE_DType[dout.dtype] + # dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: - if qkv_format in ["bshd", "sbhd"]: + if ctx.o_format in ["bshd", "sbhd"]: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) else: @@ -3718,15 +4083,14 @@ def backward(ctx, dout, *_args): dout = flash_attn_a2a_communicate( dout, chunk_ids_for_a2a, - seq_dim, + seq_dim_do, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=True, - qkv_format=qkv_format, + qkv_format=ctx.o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3741,7 +4105,7 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if qkv_format == "thd": + if ctx.o_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3767,13 +4131,21 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["softcap"] = 0.0 dq_fp8, dk_fp8, dv_fp8 = None, None, None + d_out_format = ctx.o_format if ctx.use_fused_attention: q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or ctx.fp8_recipe.mxfp8(): out_part = out - dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + if not ctx.fp8_recipe.mxfp8(): + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + else: + dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) 
+ dout_part = ctx.dO_quantizer(dout) + aux_ctx_tensors.append(dout) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3785,14 +4157,17 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + d_out_format=d_out_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, @@ -3801,7 +4176,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, Float8Tensor): + if isinstance(dq, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: @@ -3810,7 +4185,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - qkv_format, + ctx.o_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3837,18 +4212,17 @@ def backward(ctx, dout, *_args): dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, - seq_dim, + seq_dim_dqkv, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=qkv_format, + qkv_format=dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - - if qkv_format == "bshd": + if dqkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif qkv_format == "sbhd": + elif dqkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] d_bias = None @@ -3863,8 +4237,12 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: - if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: - dq, dk, dv = 
combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ( + ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() + ) and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize( + ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) @@ -3872,13 +4250,12 @@ def backward(ctx, dout, *_args): ] if not ctx.is_input_fp8: dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.dqkv_layout, dq, dk, dv, src_nominal_dtype=bwd_nominal_dtype, ) - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -4009,7 +4386,6 @@ def attn_forward_func_with_cp( in Megatron-LM. """ - if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 @@ -4056,10 +4432,11 @@ def attn_forward_func_with_cp( ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] - assert not enable_mla or cp_comm_type in [ - "p2p", - "a2a+p2p", - ], f"Context parallelism does not support MLA with {cp_comm_type=}!" + # assert not enable_mla or cp_comm_type in [ + # "p2p", + # "a2a+p2p", + # "a2a", + # ], f"Context parallelism does not support MLA with {cp_comm_type=}!" 
if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: @@ -4117,7 +4494,16 @@ def attn_forward_func_with_cp( elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [ + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, + ] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 64db4646f6..ba24aa658e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -98,19 +98,19 @@ +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | -| | | export NVTE_DPA_FP8_RECIPE="F16" | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS, NVFP4 or MXFP8 to autocast(); | +| /MXFP8 | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass FP8DS to autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export 
NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +118,19 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| | | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates 
a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | @@ -139,6 +139,11 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | +| | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ """ _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} @@ -673,11 +678,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False - if not fp8_recipe_dpa.float8_per_tensor_scaling(): - assert not ( - fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha - ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..84f676539b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,11 +35,15 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -471,6 +475,9 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -478,6 +485,10 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False + if use_flash_attention_3 and fp8_recipe.mxfp8(): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for MXFP8") + use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() @@ -485,9 +496,6 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention 
= False - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") @@ -507,6 +515,17 @@ def get_attention_backend( " with cuDNN < 9.18.0" ) use_fused_attention = False + if use_fused_attention and fp8_recipe.mxfp8(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") + use_fused_attention = False + else: + if cudnn_version < (9, 21, 0): + logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") + use_fused_attention = False + elif qkv_format == "thd": + logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -599,9 +618,9 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if use_flash_attention_2 and FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 as it does not support MLA.") - use_flash_attention_2 = False + # if use_flash_attention_2 and FlashAttentionUtils.is_installed: + # logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + # use_flash_attention_2 = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": @@ -816,10 +835,55 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: + elif fp8 and qkv_format == "thd": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" - " MLA attention" + " attention and THD format" + ) + use_fused_attention = False + elif fp8 
and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and bias" + ) + use_fused_attention = False + + elif core_attention_bias_type != "no_bias" and cp_comm_type in [ + "all_gather", + "a2a", + "a2a+p2p", + ]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias" + " and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD" + " format and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif ( + window_size is not None + and (window_size != (-1, 0) or window_size != (-1, -1)) + and cp_comm_type in ["p2p", "a2a+p2p"] + ): + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with sliding" + " window attention and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif cp_comm_type in ["a2a", "a2a+p2p"] and (num_heads % 2 != 0 or num_gqa_groups % 2 != 0): + logger.debug( + "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" + " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)", + cp_comm_type, + num_heads, + num_gqa_groups, ) use_fused_attention = False @@ -872,12 +936,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention for FP8" - ) - use_fused_attention = False - elif attention_dropout != 0.0: + if attention_dropout != 0.0: logger.debug( "Disabling 
FusedAttention as it only supports sliding window attention " "without dropout" @@ -1013,7 +1072,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention and window_size is not None and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( "Disabling FusedAttention as only sub-backend %s does not support " @@ -2095,28 +2154,45 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers): +def get_attention_quantizers(fp8, fp8_recipe, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.internal = False + O_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + dP_quantizer.set_usage(rowwise=True, columnwise=False) + + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.interal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + + if 
fp8_recipe.mxfp8(): + QKV_quantizer.columnwise_usage = True + QKV_quantizer.optimize_for_gemm = True + S_quantizer = None + O_quantizer.columnwise_usage = True + + dO_quantizer.columnwise_usage = True + dO_quantizer.optimize_for_gemm = True + dP_quantizer = None + dQKV_quantizer.columnwise_usage = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2170,18 +2246,74 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" - ) + elif isinstance(q, MXFP8Quantizer): + type_str = "MXFP8" + if type_str in ["DS", "CS"]: + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) + else: + print(f"{label} >> {names[i]:14s}: {type_str}") + + +def permute_to_grouped_tensor(src_format, tensor): + """Permute tensor to bhsd or htd format for grouped quantization in MXFP8BlockScaling. 
src_format ={bshd, sbhd, thd}""" + if src_format in ["bhsd", "htd"]: + return tensor, src_format + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") + dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] + perm = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] + tensor = tensor.permute(*perm).contiguous() + return tensor, "bhsd" if src_format != "thd" else "htd" def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype + if isinstance(qkv_quantizer, MXFP8Quantizer): + # bs3hd, sb3hd, etc -> bshd_bshd_bhsd -> bhsd_bhsd_bhsd + # t3hd, etc -> thd_thd_thd -> htd_htd_htd + if q_format not in ["bhsd", "htd"]: + q, _ = permute_to_grouped_tensor(q_format, q) + if kv_format not in ["bhsd", "htd"]: + k, _ = permute_to_grouped_tensor(kv_format, k) + v, _ = permute_to_grouped_tensor(kv_format, v) + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" + + original_shapes = [x.shape for x in [q, k, v]] + s_q, d_qk = q.shape[-2:] + s_kv, d_v = v.shape[-2:] + assert s_q % 128 == 0 + assert s_kv % 128 == 0 + assert d_qk % 32 == 0 + assert d_v % 32 == 0 + # need to check seqlens in THD % 128 == 0 + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + + # consider bhsd for now + if d_qk == d_v: + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=[q, k, v], quantizer=qkv_quantizer + ) + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + else: + # grouped_tensor = GroupedTensor.create_and_quantize( + # tensors=[q, k], quantizer=qkv_quantizer + # ) + # q_fp8, k_fp8 = grouped_tensor.quantized_tensors + q_fp8 = qkv_quantizer(q) + k_fp8 = 
qkv_quantizer(k) + v_fp8 = qkv_quantizer(v) + q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] + + return q_fp8, k_fp8, v_fp8, qkv_layout + match qkv_group: case 1: dim = qkv_layout.find("3") @@ -2221,7 +2353,7 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout def combine_and_dequantize( @@ -2230,14 +2362,19 @@ def combine_and_dequantize( """Combine q,k,v based on qkv_layout and dequantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensor) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" 
if des_nominal_dtype is None: des_nominal_dtype = src_nominal_dtype + if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): + q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] + return q, k, v + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] match qkv_group: case 1: diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..0a276bdc8a 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -784,14 +784,22 @@ def forward( fp8_dpa = fp8_recipe.fp8_dpa fp8_mha = fp8_recipe.fp8_mha float8_current_scaling = fp8_recipe.float8_current_scaling() + mxfp8_scaling = fp8_recipe.mxfp8() else: fp8_dpa = _dpa_fp8_recipe_dpa fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling - # DPA: always produce FP8 output when fp8=True to take advantage of the O amax - dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe + qkv_fp8_output = ( + fp8 + and fp8_mha + and rotary_pos_emb is None + and not float8_current_scaling + and not mxfp8_scaling + ) + # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 
101e5b2525..7a756ead1c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -41,6 +41,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, + "bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -69,6 +70,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, + "bhsd_bhsd_bhsd": NVTE_QKV_Layout.NVTE_BHSD_BHSD_BHSD, } AttnBiasType = { @@ -133,6 +135,7 @@ def fused_attn_fwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -202,6 +205,8 @@ def fused_attn_fwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -292,13 +297,6 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." - assert ( - o_quantizer is not None - ), "o_quantizer is required as an input for FP8 fused attention." 
else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -312,6 +310,7 @@ def fused_attn_fwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -365,7 +364,7 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - dqkv_dtype: tex.DType, + # dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -377,6 +376,9 @@ def fused_attn_bwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + d_out_format: str = "sbhd", + dqkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -415,8 +417,8 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - dqkv_dtype : tex.DType - data type of dQ, dK and dV; in tex.DType, not torch.dtype + # dqkv_dtype : tex.DType + # data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. 
aux_ctx_tensors = [M, ZInv, rng_state] @@ -443,6 +445,15 @@ def fused_attn_bwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + d_out_format : str, default = "sbhd" + format of dO; {"sbhd", "bshd", "thd"} + dqkv_layout : str, default = "sbh3d" + layout of dQ, dK and dV; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -497,17 +508,11 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: + # assert ( + # dqkv_dtype is not None + # ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention backward." - assert ( - dp_quantizer is not None - ), "dp_quantizer is required as an input for FP8 fused attention backward." - assert ( - dqkv_dtype is not None - ), "dqkv_dtype is required as an input for FP8 fused attention backward." - assert ( - len(aux_ctx_tensors) == 3 + len(aux_ctx_tensors) >= 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
output_tensors = tex.fused_attn_bwd( @@ -517,6 +522,9 @@ def fused_attn_bwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[d_out_format], + QKVLayout[dqkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -531,7 +539,7 @@ def fused_attn_bwd( o, d_o, fake_dtype, - dqkv_dtype, + # dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 6aab9938b3..f757bbdfee 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -300,6 +300,14 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. 
+ */ + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); + std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b2b0751b04..95c985062a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -84,7 +84,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -98,11 +98,13 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, //const DType dqkv_type, const 
std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c33..192a774ca0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -91,6 +91,22 @@ std::pair quantizer_helper(py::handle quantizer, !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // MXFP8 + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); + } } return {std::move(te_T), std::move(py_T)}; } @@ -98,7 +114,7 @@ std::pair quantizer_helper(py::handle quantizer, // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -134,8 +150,12 @@ std::vector fused_attn_fwd( std::unique_ptr O_quantizer = convert_quantizer(o_quantizer);
std::vector q_shape = convertShape(te_Q.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; + o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; + size_t b = 0, h = 0, s = 0, d = 0, t = 0; + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, + &d, &t); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -146,9 +166,7 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -156,7 +174,7 @@ std::vector fused_attn_fwd( } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (o_format == NVTE_QKV_Format::NVTE_THD) { te_O.zero_(at::cuda::getCurrentCUDAStream()); } } else { @@ -235,9 +253,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, 
window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -295,9 +313,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory @@ -310,11 +328,13 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType 
fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -343,25 +363,35 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; + std::vector dQ_shape(4), dK_shape(4), dV_shape(4); + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), + dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), + dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), + dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - std::vector tmp_shape; + DType dqkv_type = fake_dtype_te; + if (!dqkv_quantizer.is_none()) { + dqkv_type = dqkv_quantizer.attr("dtype").cast(); + } auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } + NVTE_QKV_Layout_Group layout_group = 
nvte_get_qkv_layout_group(dqkv_layout); + std::vector tmp_shape; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -378,7 +408,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -392,9 +422,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -407,9 +437,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); 
dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -420,11 +450,12 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); dK = torch::empty(tmp_shape, options); - tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + tmp_shape = std::vector(dV_shape.begin(), dV_shape.end()); dV = torch::empty(tmp_shape, options); break; default: @@ -438,7 +469,7 @@ std::vector fused_attn_bwd( // construct NVTE tensors if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD)) { if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -451,7 +482,7 @@ std::vector fused_attn_bwd( } } } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); @@ -538,9 +569,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], 
bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -555,9 +586,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e715d8f5ba..b44640d006 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1230,6 +1230,18 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data) { + at::Tensor amax_tensor = + at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + auto out = data.has_value() ? 
NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); + out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index eda5e8fc54..48a7035496 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -206,11 +206,13 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { py::handle quantizer = py::none(); DType quantizer_dtype = DType::kNumTypes; NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + bool with_gemm_swizzled_scales = false; if (!tensor.attr("quantizer").is_none()) { quantizer = tensor.attr("quantizer"); if (!quantizer.is_none()) { scaling_mode = ScalingModeFromQuantizer(quantizer); quantizer_dtype = quantizer.attr("dtype").cast(); + with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); } } auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); @@ -282,6 +284,8 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { getTensorShape(tensor_offsets)); } + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + return ret; }