From e0ae1074204267560237aab4407e4b8b7373da4c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 31 Jan 2026 13:42:45 -0800 Subject: [PATCH 01/59] initial implementation for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- tests/pytorch/attention/test_attention.py | 29 ++- .../common/fused_attn/fused_attn_fp8.cu | 191 ++++++++++++++---- .../dot_product_attention/backends.py | 2 + .../dot_product_attention.py | 8 +- .../attention/dot_product_attention/utils.py | 45 ++++- .../pytorch/cpp_extensions/fused_attn.py | 8 +- transformer_engine/pytorch/csrc/common.h | 8 + .../pytorch/csrc/extensions/attention.cpp | 20 ++ transformer_engine/pytorch/csrc/quantizer.cpp | 12 ++ .../pytorch/tensor/mxfp8_tensor.py | 1 + 11 files changed, 264 insertions(+), 62 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index b372d39879..209a25fe89 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2 +Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index bd0ac41974..dd133c840a 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2062,7 +2062,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -2095,6 +2095,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8_format=recipe.Format.HYBRID, fp8_dpa=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -2107,6 +2113,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + print(f"flash_attn_supported: {flash_attn_supported}, fused_attn_supported: {fused_attn_supported}, unfused_attn_supported: {unfused_attn_supported}") if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -2133,21 +2140,22 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - if unfused_attn_supported: - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_UNFUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") - unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training, fp8_recipe - ) + # if unfused_attn_supported: + # os.environ["NVTE_FLASH_ATTN"] = "0" + # os.environ["NVTE_FUSED_ATTN"] = "0" + # os.environ["NVTE_UNFUSED_ATTN"] = "1" + # _attention_backends["backend_selection_requires_update"] = True + # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + # unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + # dtype, config, True, qkv_layout, is_training, fp8_recipe + # ) os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") + print(f"Running fused attention") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) @@ -2158,6 +2166,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + print(f"Running fused attention with fp8_dpa = False") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( dtype, config, False, qkv_layout, is_training, fp8_recipe ) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f886ec77f4..400a11af6a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1677,9 +1677,31 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + // NVTE_CHECK(is_current_scaling || is_delayed_scaling, + // "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + // "kFloat8E5M2!"); + is_current_scaling = false; + is_delayed_scaling = false; + bool is_mxfp8 = true; + printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); + printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); + printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); + printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); + printf(">>>>>> cudnn_frontend::DataType_t::UINT8: %d\n", cudnn_frontend::DataType_t::UINT8); + printf(">>>>>> cudnn_frontend::DataType_t::INT8: %d\n", cudnn_frontend::DataType_t::INT8); + printf(">>>>>> cudnn_frontend::DataType_t::HALF: %d\n", cudnn_frontend::DataType_t::HALF); + printf(">>>>>> cudnn_frontend::DataType_t::INT64: %d\n", cudnn_frontend::DataType_t::INT64); + printf(">>>>>> cudnn_frontend::DataType_t::DOUBLE: %d\n", cudnn_frontend::DataType_t::DOUBLE); + printf(">>>>>> bias_type: %d\n", bias_type); + printf(">>>>>> mask_type: %d\n", mask_type); + printf(">>>>>> scaling_factor: %f\n", scaling_factor); + printf(">>>>>> dropout_probability: %f\n", dropout_probability); + // qkv_tensor_type = cudnn_frontend::DataType_t::FP8_E8M0; + // o_tensor_type = cudnn_frontend::DataType_t::BFLOAT16; + // printf(">>>>>> after setting qkv_tensor_type and o_tensor_type\n"); + // printf(">>>>>> qkv_tensor_type: %d\n", qkv_tensor_type); + // printf(">>>>>> o_tensor_type: %d\n", o_tensor_type); try { FADescriptor_v1 descriptor{b, @@ -1770,18 +1792,55 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); + printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); + printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); + + int32_t block_size = 32; + int64_t d_scale = (d + block_size - 1) / block_size; + int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t d_scale_padded = ((d_scale + 3) / 4) * 4; + int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; + int64_t d_padded = ((d + 3) / 4) * 4; // d dimension for SF_V (not scaled, but may need padding) + printf(">>>>>> d_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_scale_padded: %d, s_kv_scale_padded: %d, d_padded: %d\n", d_scale, s_kv_scale, s_q_padded, s_kv_padded, d_scale_padded, s_kv_scale_padded, d_padded); + std::vector q_scale_dims = {b, h, s_q_padded, d_scale_padded}; + std::vector k_scale_dims = {b, hg, s_kv_padded, d_scale_padded}; + std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_padded}; + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_scale_padded, q_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_scale_padded, k_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); +// generateMatrixStrides(b, d_padded, s_q_padded, hg, s_kv_scale_padded, v_scale_strides.data(), layout, + // generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, + // NVTE_QKV_Matrix::NVTE_V_Matrix); + v_scale_strides[0] = h*d_padded*s_kv_scale_padded; + v_scale_strides[1] = d_padded*s_kv_scale_padded; + v_scale_strides[2] = 1; + v_scale_strides[3] = s_kv_scale_padded; + printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") @@ -1789,16 +1848,36 @@ void fused_attn_fp8_fwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + if (!is_mxfp8) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + } else { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim(q_scale_dims) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim(k_scale_dims) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim(v_scale_dims) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + } if (is_delayed_scaling) { scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); @@ -1851,8 +1930,24 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( + std::shared_ptr O, Stats, amax_s, amax_o; + if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; + } else { + auto outputs = mha_graph->sdpa_fp8( Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_s = outputs[2]; + amax_o = outputs[3]; + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, @@ -1863,11 +1958,6 @@ void fused_attn_fp8_fwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h, s_q, 1}) @@ -1886,7 +1976,9 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + key_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : + std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); @@ -1896,11 +1988,16 @@ void fused_attn_fp8_fwd_impl_v1( : std::make_tuple(nullptr, nullptr); NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + printf(">>>>>> mha_graph->validate()\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + printf(">>>>>> mha_graph->build_operation_graph(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + printf(">>>>>> mha_graph->create_execution_plans({fe::HeurMode_t::A})\n"); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + printf(">>>>>> mha_graph->check_support(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - + printf(">>>>>> mha_graph->build_plans(handle)\n"); + printf(">>>>>> mha_graph->get_workspace_size(): %zu\n", mha_graph->get_workspace_size()); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); @@ -1967,7 +2064,7 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } - + printf(">>>>>> mha_graph->execute(handle, variant_pack, workspace)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2420,16 +2517,44 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = input_Q->data.dptr; - void* devPtrK = input_K->data.dptr; - void* devPtrV = input_V->data.dptr; - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; + void* devPtrQ = nullptr; + void* devPtrK = nullptr; + void* devPtrV = nullptr; + void* devPtrDescaleQ = nullptr; + void* devPtrDescaleK = nullptr; + void* devPtrDescaleV = nullptr; + void* devPtrO = nullptr; + void* devPtrAmaxO = nullptr; + void* devPtrScaleO = nullptr; + void* devPtrAmaxS = nullptr; + void* devPtrScaleS = nullptr; + void* devPtrDescaleS = nullptr; + printf(">>>>>> fused_attn_fp8_fwd\n"); + // if (input_Q->scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // printf(">>>>>> input_Q is MXFP8\n"); + // devPtrQ = input_Q-> + // devPtrQ = input_Q->get_rowwise_data_ptr(); + // devPtrDescaleQ = input_Q->get_rowwise_scale_inv_ptr(); + // devPtrK = input_K->get_rowwise_data_ptr(); + // devPtrDescaleK = input_K->get_rowwise_scale_inv_ptr(); + // devPtrV = input_V->get_rowwise_data_ptr(); + // devPtrDescaleV = input_V->get_rowwise_scale_inv_ptr(); + // devPtrO = output_O->get_rowwise_data_ptr(); + // devPtrAmaxO = output_O->amax.dptr; + // } else { + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + // } void* devPtrM = nullptr; void* devPtrZInv = nullptr; @@ -2458,10 +2583,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ef7fa0dcc0..5da38045e4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1180,6 +1180,8 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: + print(f">>>>>>> Combining and quantizing q, k, v <<<<<<<") + print(f"q: {q.shape}, k: {k.shape}, v: {v.shape}") q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 5a554d86ec..8699f22cb9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -674,10 +674,10 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False - if not fp8_recipe_dpa.float8_per_tensor_scaling(): - assert not ( - fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha - ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" +# if not fp8_recipe_dpa.float8_per_tensor_scaling(): +# assert not ( +# fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha +# ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 56e6f093d1..2490b5ccd4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -40,6 +40,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2089,14 +2090,16 @@ def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + columnwise = True QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=False) + QKV_quantizer.set_usage(rowwise=True, columnwise=columnwise) O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) - S_quantizer = quantizers["scaling_fwd"][META_S] - S_quantizer.internal = True - S_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer.set_usage(rowwise=True, columnwise=columnwise) + S_quantizer = None + # quantizers["scaling_fwd"][META_S] + # S_quantizer.internal = True + # S_quantizer.set_usage(rowwise=True, columnwise=columnwise) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True @@ -2107,6 +2110,7 @@ def get_attention_quantizers(fp8, quantizers): dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + print(f"QKV_quantizer: {QKV_quantizer}, O_quantizer: {O_quantizer}, S_quantizer: {S_quantizer}, dQKV_quantizer: {dQKV_quantizer}, dO_quantizer: {dO_quantizer}, dP_quantizer: {dP_quantizer}") return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2160,10 +2164,17 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" - ) + elif isinstance(q, MXFP8Quantizer): + type_str = "MXFP8" + if type_str in ["DS", "CS"]: + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) + else: + print( + f"{label} >> {names[i]:14s}: {type_str}" + ) def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): @@ -2172,6 +2183,22 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype + print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") + if isinstance(qkv_quantizer, MXFP8Quantizer): + print(f"Using MXFP8Quantizer") + qkv_quantizer._internal = False + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + v_permuted = v.permute(0, 2, 3, 1).contiguous() + v_fp8_permuted = qkv_quantizer(v_permuted) + print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") + # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape + v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) + print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") + print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + return q_fp8, k_fp8, v_fp8 match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 101e5b2525..6b2c21013a 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -16,6 +16,7 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer __all__ = [ @@ -293,9 +294,10 @@ def fused_attn_fwd( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." + if not isinstance(o_quantizer, MXFP8Quantizer): + assert ( + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention." assert ( o_quantizer is not None ), "o_quantizer is required as an input for FP8 fused attention." diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bc22e03097..d97c72c31c 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -274,6 +274,14 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. + */ + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c33..094188b6c9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -91,6 +91,22 @@ std::pair quantizer_helper(py::handle quantizer, !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // MXFP8 + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + } } return {std::move(te_T), std::move(py_T)}; } @@ -116,6 +132,7 @@ std::vector fused_attn_fwd( auto none = py::none(); + printf(">>>>>>> Creating QKV tensor wrappers <<<<<<<\n"); // create QKV tensor wrappers TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); @@ -123,11 +140,13 @@ std::vector fused_attn_fwd( te_V = makeTransformerEngineTensor(V, none); const DType qkv_type = te_Q.dtype(); + printf(">>>>>> Creating S tensor wrapper <<<<<<<"); // create S tensor TensorWrapper te_S; py::object py_S; std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + printf(">>>>>> Creating O tensor wrapper <<<<<<<\n"); // create O tensor TensorWrapper te_O; py::object py_O; @@ -139,6 +158,7 @@ std::vector fused_attn_fwd( const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); + printf(">>>>>> Creating Bias tensor wrapper <<<<<<<"); // construct NVTE tensors TensorWrapper te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..20820143b0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -940,6 +940,18 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax(const std::vector& shape, + DType dtype, + std::optional data) { + at::Tensor amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); + out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8dd2255d89..58d095a4f4 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -84,6 +84,7 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" + print(f"Quantizing tensor: {tensor.shape}") return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: From 23434b5b1d9b7438ab0d2aa862560f832679fdac Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:18:43 -0800 Subject: [PATCH 02/59] semi-working FP8; broken F16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 99 +++++++++++-------- .../common/fused_attn/fused_attn_fp8.cu | 49 +++++---- transformer_engine/common/fused_attn/utils.cu | 25 ++++- .../include/transformer_engine/fused_attn.h | 5 + .../common/util/pybind_helper.h | 6 +- .../attention/dot_product_attention/utils.py | 46 ++++++--- .../pytorch/cpp_extensions/fused_attn.py | 2 + 7 files changed, 157 insertions(+), 75 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 4f8367aac7..02ff448544 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -117,6 +117,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_SD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -157,6 +159,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -176,6 +180,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -195,6 +201,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -226,45 +234,58 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // 8.9: t3hd, max_s=512, d=64, padding - ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && - // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && !return_max_logit) { - if (cudnn_runtime_version >= 8900) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { + printf(">>>>>> q_dtype: %d\n", q_dtype); + printf(">>>>>> qkv_format: %d\n", qkv_format); + printf(">>>>>> q_format: %d\n", q_format); + printf(">>>>>> kv_format: %d\n", kv_format); + printf(">>>>>> layout_group: %d\n", layout_group); + printf(">>>>>> cudnn_runtime_version: %d\n", cudnn_runtime_version); + printf(">>>>>> is_training: %d\n", is_training); + printf(">>>>>> bias_type: %d\n", bias_type); + printf(">>>>>> attn_mask_type: %d\n", attn_mask_type); + if (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) { + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + // } + + // if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + // sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // // 8.9: t3hd, max_s=512, d=64, padding + // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + // (cudnn_runtime_version >= 90700 && + // // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // // sm90: fwd d<=256, bwd d=128 only + // // sm100: fwd d<=128, bwd d<=128 + // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || + // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + // (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && + // !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + // // 9.10.0: known bugs with SDPA FP8 + // (cudnn_runtime_version != 91000) && !return_max_logit) { + // if (cudnn_runtime_version >= 8900) { + // backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + // } else { + // backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + // std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." + // " Please upgrade your cuDNN version if possible." + // << std::endl; + // } + // } else +} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 400a11af6a..ea5722a831 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1659,8 +1659,8 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1673,16 +1673,15 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - // NVTE_CHECK(is_current_scaling || is_delayed_scaling, - // "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - // "kFloat8E5M2!"); - is_current_scaling = false; - is_delayed_scaling = false; - bool is_mxfp8 = true; + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); @@ -1697,6 +1696,7 @@ void fused_attn_fp8_fwd_impl_v1( printf(">>>>>> mask_type: %d\n", mask_type); printf(">>>>>> scaling_factor: %f\n", scaling_factor); printf(">>>>>> dropout_probability: %f\n", dropout_probability); + is_mxfp8 = true; // qkv_tensor_type = cudnn_frontend::DataType_t::FP8_E8M0; // o_tensor_type = cudnn_frontend::DataType_t::BFLOAT16; // printf(">>>>>> after setting qkv_tensor_type and o_tensor_type\n"); @@ -1783,6 +1783,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + printf(">>>>>> layout: %d\n", layout); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -2030,17 +2031,19 @@ void fused_attn_fp8_fwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if (is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } + if (!is_mxfp8) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[amax_s] = devPtrAmaxS; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -2548,14 +2551,19 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDescaleK = input_K->scale_inv.dptr; devPtrV = input_V->data.dptr; devPtrDescaleV = input_V->scale_inv.dptr; + // devPtrV = input_V->columnwise_data.dptr; + // devPtrDescaleV = input_V->columnwise_scale_inv.dptr; devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; - devPtrScaleO = output_O->scale.dptr; - devPtrAmaxS = input_output_S->amax.dptr; - devPtrScaleS = input_output_S->scale.dptr; - devPtrDescaleS = input_output_S->scale_inv.dptr; + // devPtrScaleO = output_O->scale.dptr; + // devPtrAmaxS = input_output_S->amax.dptr; + // devPtrScaleS = input_output_S->scale.dptr; + // devPtrDescaleS = input_output_S->scale_inv.dptr; // } - + printf(">>>>>> scaling_mode: %d\n", input_Q->scaling_mode); + printf(">>>>>> scaling_mode: %d\n", input_K->scaling_mode); + printf(">>>>>> scaling_mode: %d\n", input_V->scaling_mode); + printf(">>>>>> scaling_mode: %d\n", output_O->scaling_mode); void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { @@ -2604,7 +2612,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, + stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 727aac447b..0309cf643d 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,7 +293,30 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - } + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d * s_kv; + strideA[seqlen_dim_idx] = 1; + strideA[hidden_dim_idx] = s_kv; + // strideA[batch_dim_idx] = h * s_kv * d; + // strideA[head_dim_idx] = s_kv * d; + // strideA[seqlen_transpose_dim_idx] = d; + // strideA[hidden_transpose_dim_idx] = 1; + } + break; +} if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { strideA[seqlen_kv_dim_idx] = 1; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index cddd3d7506..7c54633989 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,6 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ + NVTE_BSHD_BSHD_BHSD = 25, /*!< BSHD_BSHD_BHSD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -70,6 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, + /*! BSHD_BSHD_BHSD QKV layouts, e.g. BSHD_BSHD_BHSD */ + NVTE_HD_HD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -90,6 +93,8 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, + /*! BSHD_BSHD_BHSD QKV format, e.g. BSHD_BSHD_BHSD */ + NVTE_BHSD = 7, }; /*! \enum NVTE_Bias_Type diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6adba23a8f..57d02bcd62 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -48,7 +48,8 @@ .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -74,7 +75,8 @@ .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ + .value("NVTE_BSHD_BSHD_BHSD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 2490b5ccd4..0dfc64a65c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2090,12 +2090,13 @@ def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 - columnwise = True + is_fwd = True + is_bwd = True QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=columnwise) + QKV_quantizer.set_usage(rowwise=True, columnwise=True) O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=columnwise) + O_quantizer.set_usage(rowwise=True, columnwise=True) S_quantizer = None # quantizers["scaling_fwd"][META_S] # S_quantizer.internal = True @@ -2103,9 +2104,9 @@ def get_attention_quantizers(fp8, quantizers): dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer.set_usage(rowwise=True, columnwise=True) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.set_usage(rowwise=True, columnwise=True) dO_quantizer.internal = True dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) @@ -2181,24 +2182,43 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, _, _ = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): print(f"Using MXFP8Quantizer") qkv_quantizer._internal = False + dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") + dim_others = [i for i in range(len(v.shape)) if i != dim_s] + perm = [*dim_others, dim_s] + # perm = [*dim_others[:-1], dim_s, dim_others[-1]] + v = v.permute(*perm).contiguous() + qkv_layout = "bshd_bshd_bhsd" + # inv = [0] * len(perm) + # for i, p in enumerate(perm): + # inv[p] = i + # v = v.permute(*inv) q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - v_permuted = v.permute(0, 2, 3, 1).contiguous() - v_fp8_permuted = qkv_quantizer(v_permuted) - print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") - # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape - v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) - print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") - print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv) return q_fp8, k_fp8, v_fp8 + + # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + # v_permuted = v.permute(0, 2, 3, 1).contiguous() + # v_fp8_permuted = qkv_quantizer(v_permuted) + # print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") + # print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") + # # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape + # v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) + # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") + # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") + # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + # return q_fp8, k_fp8, v_fp8 match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 6b2c21013a..41007912c9 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,6 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, + "bshd_bshd_bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -70,6 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, + "bshd_bshd_bhsd": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHSD, } AttnBiasType = { From dbb68b8c958735ffe89cca1881bc8d3cd5bfa871 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Feb 2026 19:52:55 -0800 Subject: [PATCH 03/59] clean up last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 5 +- .../common/fused_attn/fused_attn.cpp | 100 +++++++++--------- .../common/fused_attn/fused_attn_fp8.cu | 66 +++++------- transformer_engine/common/fused_attn/utils.cu | 16 ++- .../include/transformer_engine/fused_attn.h | 8 +- .../common/util/pybind_helper.h | 4 +- .../dot_product_attention/backends.py | 21 ++-- .../dot_product_attention.py | 5 - .../attention/dot_product_attention/utils.py | 46 +++----- .../pytorch/cpp_extensions/fused_attn.py | 18 +--- .../pytorch/csrc/extensions/attention.cpp | 4 - 11 files changed, 123 insertions(+), 170 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index dd133c840a..dad697e910 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2113,7 +2113,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - print(f"flash_attn_supported: {flash_attn_supported}, fused_attn_supported: {fused_attn_supported}, unfused_attn_supported: {unfused_attn_supported}") if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -2155,7 +2154,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") - print(f"Running fused attention") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) @@ -2166,7 +2164,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - print(f"Running fused attention with fp8_dpa = False") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( dtype, config, False, qkv_layout, is_training, fp8_recipe ) @@ -2188,7 +2185,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if unfused_attn_supported: + if False: #unfused_attn_supported: logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 02ff448544..61a8d61635 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -117,8 +117,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: - return NVTE_QKV_Layout_Group::NVTE_HD_HD_SD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_DS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -159,8 +159,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: - return NVTE_QKV_Format::NVTE_BHSD; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: + return NVTE_QKV_Format::NVTE_BHDS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -180,8 +180,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHSD: - return NVTE_QKV_Format::NVTE_BHSD; + case NVTE_QKV_Format::NVTE_BHDS: + return NVTE_QKV_Format::NVTE_BHDS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -201,8 +201,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHSD: - return NVTE_QKV_Format::NVTE_BHSD; + case NVTE_QKV_Format::NVTE_BHDS: + return NVTE_QKV_Format::NVTE_BHDS; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -234,6 +234,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); + printf(">>>>>> qkv_layout: %d\n", qkv_layout); printf(">>>>>> q_dtype: %d\n", q_dtype); printf(">>>>>> qkv_format: %d\n", qkv_format); printf(">>>>>> q_format: %d\n", q_format); @@ -243,49 +244,45 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( printf(">>>>>> is_training: %d\n", is_training); printf(">>>>>> bias_type: %d\n", bias_type); printf(">>>>>> attn_mask_type: %d\n", attn_mask_type); - if (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - // } - - // if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - // sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // // 8.9: t3hd, max_s=512, d=64, padding - // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - // (cudnn_runtime_version >= 90700 && - // // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // // sm90: fwd d<=256, bwd d=128 only - // // sm100: fwd d<=128, bwd d<=128 - // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - // (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && - // !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && - // // 9.10.0: known bugs with SDPA FP8 - // (cudnn_runtime_version != 91000) && !return_max_logit) { - // if (cudnn_runtime_version >= 8900) { - // backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - // } else { - // backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - // std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - // " Please upgrade your cuDNN version if possible." - // << std::endl; - // } - // } else -} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { + + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 8.9: t3hd, max_s=512, d=64, padding + // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + // (cudnn_runtime_version >= 90700 && + // // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // // sm90: fwd d<=256, bwd d=128 only + // // sm100: fwd d<=128, bwd d<=128 + // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || + // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHDS) && + !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + // 9.10.0: known bugs with SDPA FP8 + (cudnn_runtime_version != 91000) && !return_max_logit) { + if (cudnn_runtime_version >= 8900) { + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + } else { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && @@ -1205,6 +1202,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, false); + printf(">>>>>> fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index ea5722a831..bf4f019a67 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1696,12 +1696,6 @@ void fused_attn_fp8_fwd_impl_v1( printf(">>>>>> mask_type: %d\n", mask_type); printf(">>>>>> scaling_factor: %f\n", scaling_factor); printf(">>>>>> dropout_probability: %f\n", dropout_probability); - is_mxfp8 = true; - // qkv_tensor_type = cudnn_frontend::DataType_t::FP8_E8M0; - // o_tensor_type = cudnn_frontend::DataType_t::BFLOAT16; - // printf(">>>>>> after setting qkv_tensor_type and o_tensor_type\n"); - // printf(">>>>>> qkv_tensor_type: %d\n", qkv_tensor_type); - // printf(">>>>>> o_tensor_type: %d\n", o_tensor_type); try { FADescriptor_v1 descriptor{b, @@ -1792,7 +1786,7 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); // need to double check printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); @@ -1804,7 +1798,7 @@ void fused_attn_fp8_fwd_impl_v1( int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; int64_t d_scale_padded = ((d_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_padded = ((d + 3) / 4) * 4; // d dimension for SF_V (not scaled, but may need padding) + int64_t d_padded = ((d + 3) / 4) * 4; printf(">>>>>> d_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_scale_padded: %d, s_kv_scale_padded: %d, d_padded: %d\n", d_scale, s_kv_scale, s_q_padded, s_kv_padded, d_scale_padded, s_kv_scale_padded, d_padded); std::vector q_scale_dims = {b, h, s_q_padded, d_scale_padded}; std::vector k_scale_dims = {b, hg, s_kv_padded, d_scale_padded}; @@ -1816,13 +1810,8 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); -// generateMatrixStrides(b, d_padded, s_q_padded, hg, s_kv_scale_padded, v_scale_strides.data(), layout, - // generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, - // NVTE_QKV_Matrix::NVTE_V_Matrix); - v_scale_strides[0] = h*d_padded*s_kv_scale_padded; - v_scale_strides[1] = d_padded*s_kv_scale_padded; - v_scale_strides[2] = 1; - v_scale_strides[3] = s_kv_scale_padded; + generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); @@ -1977,8 +1966,8 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = - is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : + key_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); @@ -1997,7 +1986,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); printf(">>>>>> mha_graph->check_support(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - printf(">>>>>> mha_graph->build_plans(handle)\n"); + printf(">>>>>> mha_graph->build_plans(handle)\n"); printf(">>>>>> mha_graph->get_workspace_size(): %zu\n", mha_graph->get_workspace_size()); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); @@ -2532,19 +2521,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrAmaxS = nullptr; void* devPtrScaleS = nullptr; void* devPtrDescaleS = nullptr; - printf(">>>>>> fused_attn_fp8_fwd\n"); - // if (input_Q->scaling_mode() == NVTE_MXFP8_1D_SCALING) { - // printf(">>>>>> input_Q is MXFP8\n"); - // devPtrQ = input_Q-> - // devPtrQ = input_Q->get_rowwise_data_ptr(); - // devPtrDescaleQ = input_Q->get_rowwise_scale_inv_ptr(); - // devPtrK = input_K->get_rowwise_data_ptr(); - // devPtrDescaleK = input_K->get_rowwise_scale_inv_ptr(); - // devPtrV = input_V->get_rowwise_data_ptr(); - // devPtrDescaleV = input_V->get_rowwise_scale_inv_ptr(); - // devPtrO = output_O->get_rowwise_data_ptr(); - // devPtrAmaxO = output_O->amax.dptr; - // } else { + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + printf(">>>>>> input_Q is MXFP8_1D_SCALING\n"); devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; @@ -2555,15 +2533,21 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // devPtrDescaleV = input_V->columnwise_scale_inv.dptr; devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; - // devPtrScaleO = output_O->scale.dptr; - // devPtrAmaxS = input_output_S->amax.dptr; - // devPtrScaleS = input_output_S->scale.dptr; - // devPtrDescaleS = input_output_S->scale_inv.dptr; - // } - printf(">>>>>> scaling_mode: %d\n", input_Q->scaling_mode); - printf(">>>>>> scaling_mode: %d\n", input_K->scaling_mode); - printf(">>>>>> scaling_mode: %d\n", input_V->scaling_mode); - printf(">>>>>> scaling_mode: %d\n", output_O->scaling_mode); + } else { + printf(">>>>>> input_Q is not MXFP8_1D_SCALING\n"); + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { @@ -2605,7 +2589,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHDS)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 0309cf643d..94a495153e 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,7 +293,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { strideA[batch_dim_idx] = s_q * h * d; @@ -310,10 +310,16 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[head_dim_idx] = d * s_kv; strideA[seqlen_dim_idx] = 1; strideA[hidden_dim_idx] = s_kv; - // strideA[batch_dim_idx] = h * s_kv * d; - // strideA[head_dim_idx] = s_kv * d; - // strideA[seqlen_transpose_dim_idx] = d; - // strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = 1; + strideA[hidden_transpose_dim_idx] = h * d; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d * s_kv; + strideA[seqlen_transpose_dim_idx] = s_kv; + strideA[hidden_transpose_dim_idx] = 1; } break; } diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 7c54633989..bc97d2a853 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,7 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ - NVTE_BSHD_BSHD_BHSD = 25, /*!< BSHD_BSHD_BHSD layout */ + NVTE_BSHD_BSHD_BHDS = 25, /*!< BSHD_BSHD_BHDS layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -71,8 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, - /*! BSHD_BSHD_BHSD QKV layouts, e.g. BSHD_BSHD_BHSD */ - NVTE_HD_HD_SD = 6, + /*! BSHD_BSHD_BHDS QKV layouts, e.g. BSHD_BSHD_BHDS */ + NVTE_HD_HD_DS = 6, }; /*! \enum NVTE_QKV_Format @@ -94,7 +94,7 @@ enum NVTE_QKV_Format { /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, /*! BSHD_BSHD_BHSD QKV format, e.g. BSHD_BSHD_BHSD */ - NVTE_BHSD = 7, + NVTE_BHDS = 7, }; /*! \enum NVTE_Bias_Type diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 57d02bcd62..b81c488005 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -49,7 +49,7 @@ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ - .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ + .value("NVTE_BHDS", NVTE_QKV_Format::NVTE_BHDS); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -76,7 +76,7 @@ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ - .value("NVTE_BSHD_BSHD_BHSD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHSD); \ + .value("NVTE_BSHD_BSHD_BHDS", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5da38045e4..cce4159ea6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -35,6 +35,7 @@ restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, @@ -168,7 +169,7 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - q_fp8, k_fp8, v_fp8 = combine_and_quantize( + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( @@ -192,7 +193,7 @@ def backward(ctx, grad1, grad2, grad3): tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer ) tensors = combine_and_dequantize( @@ -1180,9 +1181,7 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - print(f">>>>>>> Combining and quantizing q, k, v <<<<<<<") - print(f"q: {q.shape}, k: {k.shape}, v: {v.shape}") - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) # print quantizers print_quantizers( @@ -1237,11 +1236,17 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - - if isinstance(out_, Float8Tensor): + print(f"out_: {type(out_)} {out_.shape}") + print(f"is_output_fp8: {is_output_fp8}") + print(f"is_bwd_fp8: {is_bwd_fp8}") + print(f"fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}") + print(f"_dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") + if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): + print(f"dequantizing out_") if not is_output_fp8 or not is_bwd_fp8: out = out_.dequantize().view(out_.shape) else: + print(f"quantizing out_") if is_output_fp8 or ( is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) @@ -1562,7 +1567,7 @@ def backward(ctx, d_out, *_args): ) if not is_float8tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv = combine_and_quantize( + dq, dk, dv, ctx.qkv_layout = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 8699f22cb9..5da35d157f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -673,11 +673,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False -# if not fp8_recipe_dpa.float8_per_tensor_scaling(): -# assert not ( -# fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha -# ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 0dfc64a65c..eaeecaca4b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2090,28 +2090,27 @@ def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 - is_fwd = True - is_bwd = True QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=True) O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=True) - S_quantizer = None - # quantizers["scaling_fwd"][META_S] - # S_quantizer.internal = True - # S_quantizer.set_usage(rowwise=True, columnwise=columnwise) + O_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(QKV_quantizer, MXFP8Quantizer): + S_quantizer = None + else: + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=True) + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=True) + dO_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer.internal = True dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True - print(f"QKV_quantizer: {QKV_quantizer}, O_quantizer: {O_quantizer}, S_quantizer: {S_quantizer}, dQKV_quantizer: {dQKV_quantizer}, dO_quantizer: {dO_quantizer}, dP_quantizer: {dP_quantizer}") return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2187,38 +2186,25 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): src_nominal_dtype = q.dtype print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): - print(f"Using MXFP8Quantizer") qkv_quantizer._internal = False dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") dim_others = [i for i in range(len(v.shape)) if i != dim_s] perm = [*dim_others, dim_s] # perm = [*dim_others[:-1], dim_s, dim_others[-1]] v = v.permute(*perm).contiguous() - qkv_layout = "bshd_bshd_bhsd" - # inv = [0] * len(perm) - # for i, p in enumerate(perm): - # inv[p] = i + qkv_layout = "bshd_bshd_bhds" + inv = [0] * len(perm) + for i, p in enumerate(perm): + inv[p] = i # v = v.permute(*inv) q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv) - return q_fp8, k_fp8, v_fp8 - - # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - # v_permuted = v.permute(0, 2, 3, 1).contiguous() - # v_fp8_permuted = qkv_quantizer(v_permuted) - # print(f"v_fp8: {v_fp8._rowwise_scale_inv.shape}") - # print(f"v_fp8_permuted: {v_fp8_permuted._rowwise_scale_inv.shape}") - # # v_fp8_permuted_rowwise_shape = v_fp8._rowwise_scale_inv.permute(0, 2, 3, 1).shape - # v_fp8._rowwise_scale_inv = v_fp8_permuted._rowwise_scale_inv.view(2,16,128,-1).permute(0, 3, 1, 2)#.view(-1,128) - # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}, v_fp8.stride(): {v_fp8._rowwise_data.stride()}") - # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}, v_fp8.stride(): {v_fp8._rowwise_scale_inv.stride()}") - # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - # return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout + match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 41007912c9..2748228b42 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,7 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, - "bshd_bshd_bhsd": NVTE_QKV_Format.NVTE_BHSD, + "bshd_bshd_bhds": NVTE_QKV_Format.NVTE_BHDS, } QKVLayout = { @@ -71,7 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, - "bshd_bshd_bhsd": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHSD, + "bshd_bshd_bhds": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHDS, } AttnBiasType = { @@ -295,14 +295,6 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - if not isinstance(o_quantizer, MXFP8Quantizer): - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." - assert ( - o_quantizer is not None - ), "o_quantizer is required as an input for FP8 fused attention." else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -501,12 +493,6 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention backward." - assert ( - dp_quantizer is not None - ), "dp_quantizer is required as an input for FP8 fused attention backward." assert ( dqkv_dtype is not None ), "dqkv_dtype is required as an input for FP8 fused attention backward." diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 094188b6c9..0d7a842ce1 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -132,7 +132,6 @@ std::vector fused_attn_fwd( auto none = py::none(); - printf(">>>>>>> Creating QKV tensor wrappers <<<<<<<\n"); // create QKV tensor wrappers TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); @@ -140,13 +139,11 @@ std::vector fused_attn_fwd( te_V = makeTransformerEngineTensor(V, none); const DType qkv_type = te_Q.dtype(); - printf(">>>>>> Creating S tensor wrapper <<<<<<<"); // create S tensor TensorWrapper te_S; py::object py_S; std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); - printf(">>>>>> Creating O tensor wrapper <<<<<<<\n"); // create O tensor TensorWrapper te_O; py::object py_O; @@ -158,7 +155,6 @@ std::vector fused_attn_fwd( const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); - printf(">>>>>> Creating Bias tensor wrapper <<<<<<<"); // construct NVTE tensors TensorWrapper te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; From c627231a4f094acefdc2cc42e3b38c636aa0f095 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Feb 2026 20:13:32 -0800 Subject: [PATCH 04/59] comment out F16 pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index dad697e910..0301f77ae8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2158,15 +2158,15 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - if config.dropout_p == 0.0: - # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training, fp8_recipe - ) + # os.environ["NVTE_FLASH_ATTN"] = "0" + # os.environ["NVTE_FUSED_ATTN"] = "1" + # os.environ["NVTE_UNFUSED_ATTN"] = "0" + # if config.dropout_p == 0.0: + # # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + # fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + # dtype, config, False, qkv_layout, is_training, fp8_recipe + # ) atol = 5e-1 rtol = 5e-2 From 3f3b9e64f09bd4028157ded1b0bfd66157af6a72 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 14:46:37 -0800 Subject: [PATCH 05/59] pull in grouped_quantize for MXFP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 254 ++++++++++-------- transformer_engine/common/common.h | 5 +- 2 files changed, 148 insertions(+), 111 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 7801a2064d..df4317b547 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -21,6 +21,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "../core/common.cuh" +#include "swizzle.cuh" namespace transformer_engine { namespace dispatch { @@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso template + bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES> __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, @@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; + using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; + if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { return; @@ -475,8 +478,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + // const size_t scale_idx = + // global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -602,7 +612,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + // const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } scales_rowwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -738,7 +755,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; checkCuDriverContext(stream); - CheckNoopTensor(*noop, "cast_noop"); + // CheckNoopTensor(*noop, "cast_noop"); const bool use_rowwise_scaling = output->has_data(); const bool use_colwise_scaling = output->has_columnwise_data(); @@ -751,6 +768,13 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } else if (!use_rowwise_scaling) { scaling_type = ScalingType::COLWISE; } + // if (use_rowwise_scaling && (!use_colwise_scaling)) { + // scaling_type = ScalingType::ROWWISE; + // } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + // scaling_type = ScalingType::COLWISE; + // } else if (use_rowwise_scaling && use_colwise_scaling) { + // scaling_type = ScalingType::BIDIMENSIONAL; + // } ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; if (output->all_same_shape()) { @@ -827,6 +851,12 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } + const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { @@ -848,111 +878,115 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations input->dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, - BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - auto kernel = group_quantize_mxfp8_kernel; - switch (scaling_type) { - case ScalingType::ROWWISE: { - kernel = group_quantize_mxfp8_kernel; - break; - } - case ScalingType::COLWISE: { - kernel = group_quantize_mxfp8_kernel; - break; - } - case ScalingType::BIDIMENSIONAL: { - kernel = group_quantize_mxfp8_kernel; - break; - } - } - - // Update tensor descriptors before launching the kernel - if (!is_single_tensor) { - const IType *const input_dptr = reinterpret_cast(input->data.dptr); - - const IType *const act_input_dptr = - IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; - - OType *const output_rowwise_dptr = - use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; - - OType *const output_colwise_dptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - update_tma_descriptors<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, - output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, - use_colwise_scaling, IS_DACT); - } - - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - - if constexpr (IS_DBIAS) { - common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - } - - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + auto kernel = group_quantize_mxfp8_kernel; + switch (scaling_type) { + case ScalingType::ROWWISE: { + kernel = group_quantize_mxfp8_kernel; + break; + } + case ScalingType::COLWISE: { + kernel = group_quantize_mxfp8_kernel; + break; + } + case ScalingType::BIDIMENSIONAL: { + kernel = group_quantize_mxfp8_kernel; + break; + } + } + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, + use_colwise_scaling, IS_DACT); + } + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError()); + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 970b7aef6c..66b7e30187 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -333,6 +333,7 @@ struct GroupedTensor { NVTEScalingMode scaling_mode; size_t num_tensors; NVTEGroupedTensor nvte_tensor; + bool with_gemm_swizzled_scales = false; GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors) : data(), @@ -348,7 +349,8 @@ struct GroupedTensor { tensor_offsets(nullptr, std::vector{0}, DType::kInt64), logical_shape(nvte_make_shape(nullptr, 1)), scaling_mode(scaling_mode), - nvte_tensor(0) {} + nvte_tensor(0), + with_gemm_swizzled_scales(false) {} explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; } @@ -400,6 +402,7 @@ struct GroupedTensor { num_tensors = 0; scaling_mode = NVTE_DELAYED_TENSOR_SCALING; nvte_tensor = 0; + with_gemm_swizzled_scales = false; } }; From 850b16e72eef8d5e2cd6a0ef7378e100b165903c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:13:52 -0800 Subject: [PATCH 06/59] grouped tensor - pytorch Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_grouped_tensor.py | 450 ++++++++ transformer_engine/common/cast/cast.cu | 2 +- .../common/cast/dispatch/quantize.cuh | 1 + .../cast/mxfp8/group_quantize_mxfp8.cuh | 2 + .../common/cast/mxfp8/quantize_mxfp8.cuh | 1 + .../transformer_engine/transformer_engine.h | 236 +++++ transformer_engine/common/recipe/__init__.py | 34 +- .../common/transformer_engine.cpp | 114 +++ .../attention/dot_product_attention/utils.py | 5 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 13 + .../pytorch/csrc/extensions/pybind.cpp | 4 + transformer_engine/pytorch/csrc/pybind.h | 2 + .../pytorch/csrc/type_converters.cpp | 121 +++ .../pytorch/tensor/mxfp8_tensor.py | 45 +- .../pytorch/tensor/storage/__init__.py | 1 + .../pytorch/tensor/storage/grouped_tensor.py | 964 ++++++++++++++++++ 17 files changed, 1981 insertions(+), 16 deletions(-) create mode 100644 tests/pytorch/test_grouped_tensor.py create mode 100644 transformer_engine/pytorch/tensor/storage/grouped_tensor.py diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py new file mode 100644 index 0000000000..964c2d8e97 --- /dev/null +++ b/tests/pytorch/test_grouped_tensor.py @@ -0,0 +1,450 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for GroupedTensor class""" + +from typing import List, Tuple +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch import ( + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +from transformer_engine.pytorch.constants import TE_DType_To_Torch +import transformer_engine_torch as tex + +# Check available recipes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +_quantization_params = [ + pytest.param( + "fp8_delayed_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_blockwise", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + ), +] + + +def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: + """Create quantizers for given quantization scheme""" + + if quantization == "fp8_delayed_scaling": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device="cuda"), + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + quantizer.set_usage(rowwise=True, columnwise=False) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization == "nvfp4": + quantizer = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + else: + raise ValueError(f"Unknown quantization scheme: {quantization}") + + quantizer.internal = False + + return quantizer + + +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): + return qtensor._data + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): + return qtensor._rowwise_data + raise ValueError(f"Unknown quantization scheme: {quantization}") + + +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: + if quantization == "nvfp4": + return numel // 2 + return numel + + +class TestGroupedTensor: + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_basic_construction_all_same_shape(self) -> None: + """Test GroupedTensor construction with all tensors having same shape""" + num_tensors = 4 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.logical_shape == (num_tensors * 256, 512) + assert grouped_tensor.get_common_first_dim() == 256 + assert grouped_tensor.get_common_last_dim() == 512 + assert grouped_tensor.has_data() + + def test_basic_construction_varying_first_dim(self) -> None: + """Test GroupedTensor construction with varying first dimension""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.logical_shape == ( + sum(v for v, _ in shape), + shape[0][1], + ) # sum of first dims + + def test_split_into_quantized_tensors_no_quantization(self) -> None: + """Test split_into_quantized_tensors for unquantized tensors""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor has correct shape and shares storage + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + assert isinstance(tensor, torch.Tensor) + assert not hasattr(tensor, "_data") # Not a quantized tensor + + # Verify data pointer is within the original grouped tensor storage + # The tensor should be a view of the original data + assert tensor.data_ptr() >= original_data_ptr + + # Calculate expected offset + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: + """Test split_into_quantized_tensors for quantized tensors""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor shares storage with the grouped tensor + for i, tensor in enumerate(tensors): + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) + assert rowwise_data is not None + assert rowwise_data.data_ptr() >= original_data_ptr + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_split_varying_shapes(self) -> None: + """Test split_into_quantized_tensors with varying shapes""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + original_data_ptr = grouped_tensor.data.data_ptr() + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify shapes and storage + cumulative_offset = 0 + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + expected_offset = cumulative_offset * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + cumulative_offset += shape[i][0] * shape[i][1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_inplace(self, quantization: str) -> None: + """Test that quantize is done in-place for all recipes""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers before quantization + original_data_ptr = grouped_tensor.data.data_ptr() + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() + original_scale_ptr = ( + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None + ) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointers haven't changed (in-place operation) + assert grouped_tensor.data.data_ptr() == original_data_ptr + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr + if original_scale_ptr is not None: + assert grouped_tensor.scale.data_ptr() == original_scale_ptr + + # Verify returned tensors point to the same storage + for i, qtensor in enumerate(quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_varying_shapes(self, quantization: str) -> None: + """Test quantize with varying shapes""" + num_tensors = 3 + shape = [(256, 512), (512, 512), (768, 512)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers + original_data_ptr = grouped_tensor.data.data_ptr() + + # Create input tensors with varying shapes + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointer hasn't changed + assert grouped_tensor.data.data_ptr() == original_data_ptr + + # Verify each tensor points to correct location + cumulative_numel = 0 + for qtensor, tensor_shape in zip(quantized_tensors, shape): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + cumulative_numel += tensor_shape[0] * tensor_shape[1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_static_quantize_method(self, quantization: str) -> None: + """Test the static quantize method""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Use static quantize method + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=input_tensors, + quantizer=quantizers, + device="cuda", + ) + + # Verify the grouped tensor was created correctly + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.has_data() + + # Verify quantized_tensors were created and point to same storage + assert grouped_tensor.quantized_tensors is not None + assert len(grouped_tensor.quantized_tensors) == num_tensors + + original_data_ptr = grouped_tensor.data.data_ptr() + for i, qtensor in enumerate(grouped_tensor.quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped quantization for MXFP8 against per-tensor quantization.""" + # Test wont pass until the grouped quantization PR from Oleg is merged. + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a grouped tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] + for q in quantizers: + q.optimize_for_gemm=True + quantized_tensors = [q(tensor) for q, tensor in zip(quantizers, input_tensors)] + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.bfloat16, + ) + + offset = 0 + for tensor in input_tensors: + numel = tensor.numel() + grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + # Create MXFP8 output grouped tensor (rowwise only for easier validation) + # quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] + + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizers=quantizers, + device="cuda", + ) + print(f">>>>>>>>>>>> tex.quantize_grouped") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.data.shape if grouped_input.data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.scale_inv.shape if grouped_input.scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_data.shape if grouped_input.columnwise_data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") + print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_scale_inv.shape if grouped_input.columnwise_scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") + # Quantize using grouped API (handle both 2-arg and 3-arg bindings) + _ = tex.quantize_grouped(grouped_input, grouped_output) + # Build expected output by quantizing each tensor independently + expected_data = [] + expected_scale_inv = [] + for tensor, quantizer in zip(input_tensors, quantizers): + qtensor = quantizer(tensor) + expected_data.append(qtensor._rowwise_data.reshape(-1)) + expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) + + expected_data = torch.cat(expected_data) + expected_scale_inv = torch.cat(expected_scale_inv) + + assert torch.equal(grouped_output.data, expected_data) + assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + + def test_clear(self) -> None: + """Test clear method""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.has_data() + assert grouped_tensor.num_tensors == num_tensors + + grouped_tensor.clear() + + assert not grouped_tensor.has_data() + assert grouped_tensor.num_tensors == 0 + assert grouped_tensor.data is None + assert grouped_tensor.logical_shape == (0, 0) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 582172a88e..624b0bfc7c 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,7 +30,7 @@ void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; - + printf(">>>>>>>>>>>> nvte_group_quantize\n"); constexpr bool IS_ACT = false; dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index b83df1dedf..9a6e9b01d6 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -375,6 +375,7 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, template void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + printf(">>>>>>>>>>>> group_quantize_fwd_helper\n"); using namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index df4317b547..35e605067d 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -244,6 +244,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { +printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -852,6 +853,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + printf(">>>>>>>>>>>> with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 70a68132ad..a3e7db94d1 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -55,6 +55,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { +printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..8f3025a86a 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,6 +449,7 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ + kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ kNVTENumGroupedTensorParams }; @@ -499,6 +500,25 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorP NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name); +/*! \brief Set a parameter of the grouped tensor. + * + * \param[in/out] tensor Grouped tensor. + * \param[in] param_name The parameter to be set. + * \param[in] param The value to be set (NVTEBasicTensor). + */ +void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, + const void *buf, size_t size_in_bytes); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Get a value of the parameter of the grouped tensor. + * + * \param[in] tensor Grouped tensor. + * \param[in] param_name The parameter to be queried. + * + * \return NVTEBasicTensor containing the parameter data. + */ +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, void *buf, size_t size_in_bytes, size_t *size_written); + /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. * @@ -957,6 +977,222 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + + class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. */ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + const auto val = static_cast(with_gemm_swizzled_scales); + nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val)); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + bool get_with_gemm_swizzled_scales() const { + uint8_t val = 0; + nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), nullptr); + return static_cast(val); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; + }; + /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..da1bf03b02 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -87,32 +87,38 @@ class Recipe: """ Base recipe class. """ - - def nvfp4(self): + @classmethod + def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" - return isinstance(self, NVFP4BlockScaling) + return issubclass(cls, NVFP4BlockScaling) - def mxfp8(self): + @classmethod + def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" - return isinstance(self, MXFP8BlockScaling) + return issubclass(cls, MXFP8BlockScaling) - def delayed(self): + @classmethod + def delayed(cls): """Whether the given recipe is delayed scaling.""" - return isinstance(self, DelayedScaling) + return issubclass(cls, DelayedScaling) - def float8_current_scaling(self): + @classmethod + def float8_current_scaling(cls): """Whether the given recipe is (per-tensor) current scaling.""" - return isinstance(self, Float8CurrentScaling) + return issubclass(cls, Float8CurrentScaling) - def float8_per_tensor_scaling(self): + @classmethod + def float8_per_tensor_scaling(cls): """Whether the given recipe is per-tensor scaling.""" - return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) - def float8_block_scaling(self): + @classmethod + def float8_block_scaling(cls): """Whether the given recipe is float8 blockwise scaling.""" - return isinstance(self, Float8BlockScaling) + return issubclass(cls, Float8BlockScaling) - def custom(self): + @classmethod + def custom(cls): """Whether the given recipe is custom.""" return isinstance(self, CustomRecipe) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 06971443dd..d0d6b533c8 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1268,3 +1268,117 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); return t.logical_shape; } + +void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, const void *buf, + size_t size_in_bytes) { +// Check attribute and buffer +NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), +")"); +NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); +auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + +// Read from buffer +switch (param) { +case kNVTEGroupedRowwiseData: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.data = *basic_tensor; +break; +} +case kNVTEGroupedColumnwiseData: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.columnwise_data = *basic_tensor; +break; +} +case kNVTEGroupedScale: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.scale = *basic_tensor; +break; +} +case kNVTEGroupedAmax: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.amax = *basic_tensor; +break; +} +case kNVTEGroupedRowwiseScaleInv: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.scale_inv = *basic_tensor; +break; +} +case kNVTEGroupedColumnwiseScaleInv: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.columnwise_scale_inv = *basic_tensor; +break; +} +case kNVTEGroupedColumnwiseAmax: { +const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +t.columnwise_amax = *basic_tensor; +break; +} +case kNVTEGroupedWithGEMMSwizzledScales: +t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); +break; +default: +NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); +} +} + +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, void *buf, + size_t size_in_bytes, size_t *size_written) { +using namespace transformer_engine; + +// Check param +NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), +")"); + +// Return immediately if buffer is not provided +if (buf == nullptr) { +return; +} + +// Get C++ tensor +const GroupedTensor *t = convertNVTEGroupedTensor(tensor); + +// Write to buffer +switch (param) { +case kNVTEGroupedRowwiseData: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->data); +break; +} +case kNVTEGroupedColumnwiseData: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->columnwise_data); +break; +} +case kNVTEGroupedScale: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->scale); +break; +} +case kNVTEGroupedAmax: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->amax); +break; +} +case kNVTEGroupedRowwiseScaleInv: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->scale_inv); +break; +} +case kNVTEGroupedColumnwiseScaleInv: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->columnwise_scale_inv); +break; +} +case kNVTEGroupedColumnwiseAmax: { +NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); +*basic_tensor = static_cast(t->columnwise_amax); +break; +} +case kNVTEGroupedWithGEMMSwizzledScales: +*reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); +break; +default: +NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); +} +} diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index a957976235..78083c0b0b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2188,6 +2188,11 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): src_nominal_dtype = q.dtype print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): + # bs3hd -> bshd_bshd_bhsd + q,k,v = [x.contiguous() for x in [q, k, v]] + + # bshd_bshd_bhsd -> bhsd_bhsd_bhsd + qkv_quantizer.optimize_for_gemm = True qkv_quantizer._internal = False dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") dim_others = [i for i in range(len(v.shape)) if i != dim_s] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..d91ec308fa 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -250,6 +250,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); +py::object quantize_grouped(const py::handle &input, py::handle &output); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5c9d0f5b07..34565bcf44 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -80,6 +80,19 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +py::object quantize_grouped(const py::handle &input, py::handle &output) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + printf(">>>>>>>>>>>> quantize_grouped\n"); + const auto &grouped_input_tensor = GroupedTensorFromPyTorchGroupedTensor(input); + const auto &grouped_output_tensor = GroupedTensorFromPyTorchGroupedTensor(output); + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor.data(), at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(output); +} + py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..12abd503cf 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -122,6 +122,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("quantize_grouped", transformer_engine::pytorch::quantize_grouped, "Quantize grouped tensor", + py::arg("input"), + py::arg("output")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 25ffef0588..9541409c0c 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -95,6 +95,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66f..8ab8dc1d48 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,127 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. + const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + bool with_gemm_swizzled_scales = false; + if (!tensor.attr("quantizers").is_none()) { + const auto quantizers = tensor.attr("quantizers").cast(); + quantizer = quantizers[0]; + if (!quantizers.empty() && !quantizer.is_none()) { + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); + printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); + } + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 58d095a4f4..a283b43908 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -165,6 +165,49 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if columnwise: + # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # with padding to multiples of [4, 128] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # with padding to multiples of [128, 4] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + + def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.""" + return rowwise_data_shape + def create_tensor_from_data( self, data: torch.Tensor, @@ -705,7 +748,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=fp8_dtype, dtype=param_dtype, - shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, ) diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index d7a2719200..54ed5caa60 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,3 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 +from .grouped_tensor import GroupedTensor # noqa: F401 \ No newline at end of file diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py new file mode 100644 index 0000000000..ad85a448e6 --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -0,0 +1,964 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations +from typing import Optional, Tuple, List, Union +import math + +import torch + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ..mxfp8_tensor import MXFP8Tensor +from ..nvfp4_tensor import NVFP4Tensor +from ..float8_tensor import Float8Tensor +from ..float8_blockwise_tensor import Float8BlockwiseQTensor +from .float8_tensor_storage import Float8TensorStorage +from .mxfp8_tensor_storage import MXFP8TensorStorage +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .nvfp4_tensor_storage import NVFP4TensorStorage + + +class GroupedTensor: + """ + EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. + + Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode. + + Shape Representation: + - logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors + are flattened to 2D and stacked together (REQUIRED) + + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N) + + When varying_first_dim(): [~sum_of_first_dims, N] where N is common + + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common + + When varying_both_dims(): [1, total_elements] (fully flattened) + + - first_dims and last_dims are OPTIONAL (None if dimension is uniform) + + None first_dims: all tensors have the same first dimension + + None last_dims: all tensors have the same last dimension + + Both None: all tensors have identical shapes + + Both set: each tensor has unique shape (first_dims[i], last_dims[i]) + + Data Layout: + - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.) + - logical_shape provides the conceptual 2D interpretation + - All data is stored on device in contiguous layout + + Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. + """ + + def __init__( + self, + num_tensors: int, + shape: List[Tuple[int, int]], + quantizers: Optional[List[Quantizer]] = None, + dtype: Optional[torch.dtype] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + logical_shape: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initialize a GroupedTensor. + + Args: + num_tensors: Number of tensors in the group + shape: 2D shape of each tensor (len num_tensors) + quantizers: List of Quantizers for the grouped tensor + data: Row-wise data buffer (1D flattened) + columnwise_data: Column-wise data buffer (1D flattened) + scale_inv: Row-wise scale inverse buffer + columnwise_scale_inv: Column-wise scale inverse buffer + amax: Row-wise amax buffer + columnwise_amax: Column-wise amax buffer + scale: Scale buffer (for FP8-DS only) + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) + offsets: Vector of integer offsets for each tensor. + logical_shape: 2D tuple representing conceptual shape + """ + print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") + print(f">>>>>>>>>>>> shape: {shape}") + print(f">>>>>>>>>>>> dtype: {dtype}") + print(f">>>>>>>>>>>> data: {data.shape}") + print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") + print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") + print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") + print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") + print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") + print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") + print(f">>>>>>>>>>>> first_dims: {first_dims}") + print(f">>>>>>>>>>>> last_dims: {last_dims}") + print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets}") + print(f">>>>>>>>>>>> offsets: {offsets}") + print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets}") + print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets}") + print(f">>>>>>>>>>>> logical_shape: {logical_shape}") + print(f">>>>>>>>>>>> num_tensors: {num_tensors}") + + self.num_tensors = num_tensors + self.quantizers = quantizers + self.shape = shape + self.dtype = ( + dtype if dtype is not None else torch.float32 + ) # Default to float32 if not provided + + # Data buffers + self.data = data + self.columnwise_data = columnwise_data + self.scale_inv = scale_inv + self.columnwise_scale_inv = columnwise_scale_inv + self.amax = amax + self.columnwise_amax = columnwise_amax + self.scale = scale + + # For convenient indexing for python GroupedTensor API. + self.scale_inv_offsets = scale_inv_offsets + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + + # Shape information (OPTIONAL - None if dimension is uniform across all tensors) + # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) + # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) + self.first_dims = ( + first_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + self.last_dims = ( + last_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + + # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape()) + # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) + # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size + # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) + self.tensor_offsets = ( + tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) + ) + self.offsets = offsets # Vector of integer offsets for each tensor. + + # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) + # Represents how the 1D flattened data should be interpreted as 2D + # Always 2D with positive dimensions + self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + + # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. + # Used as a convenience. + self.quantized_tensors = None + + def has_data(self) -> bool: + """ + Check if the tensor has row-wise data. + + Returns: + True if data buffer is initialized, False otherwise + """ + return self.data is not None + + def has_columnwise_data(self) -> bool: + """ + Check if the tensor has column-wise data. + + Returns: + True if columnwise_data buffer is initialized, False otherwise + """ + return self.columnwise_data is not None + + def all_same_first_dim(self) -> bool: + """ + Check if all tensors in the group have the same first dimension. + + Returns: + True if first dimension is uniform across all tensors + """ + return self.first_dims is None + + def all_same_last_dim(self) -> bool: + """ + Check if all tensors in the group have the same last dimension. + + Returns: + True if last dimension is uniform across all tensors + """ + return self.last_dims is None + + def all_same_shape(self) -> bool: + """ + Check if all tensors in the group have identical shapes. + + Returns: + True if all tensors have the same shape + """ + return self.first_dims is None and self.last_dims is None + + def varying_both_dims(self) -> bool: + """ + Check if both dimensions vary across tensors. + + Returns: + True if both first and last dimensions vary + """ + return self.first_dims is not None and self.last_dims is not None + + def get_common_first_dim(self) -> int: + """ + Get the common first dimension when all tensors share it. + + Returns: + The common first dimension + + Raises: + RuntimeError: If first dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_first_dim(): + raise RuntimeError("First dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + if self.all_same_shape(): + # When both dims are uniform: logical_shape = [num_tensors * M, N] + return self.logical_shape[0] // self.num_tensors + # When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims] + return self.logical_shape[0] + + def get_common_last_dim(self) -> int: + """ + Get the common last dimension when all tensors share it. + + Returns: + The common last dimension + + Raises: + RuntimeError: If last dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_last_dim(): + raise RuntimeError("Last dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + # For both uniform and varying first dim cases: logical_shape[1] is the common last dim + return self.logical_shape[1] + + def get_dtype(self) -> torch.dtype: + """ + Get the high precision data type of the tensor. + + Returns: + The high precision dtype of the data buffer + """ + + return self.dtype + + def clear(self) -> None: + """ + Reset tensor data and clear all buffers. + """ + self.data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.logical_shape = (0, 0) + self.num_tensors = 0 + self.quantizers = None + self.quantized_tensors = None + self.offsets = None + self.scale_inv_offsets = None + self.columnwise_scale_inv_offsets = None + + def __repr__(self) -> str: + """String representation of the GroupedTensor.""" + return ( + f"GroupedTensor(num_tensors={self.num_tensors}, " + f"shape={self.shape}, " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()})" + ) + + def __str__(self) -> str: + """User-friendly string representation.""" + shape_info = [] + if self.all_same_shape(): + shape_info.append("uniform shape") + else: + if not self.all_same_first_dim(): + shape_info.append("varying first dim") + if not self.all_same_last_dim(): + shape_info.append("varying last dim") + + return ( + f"GroupedTensor with {self.num_tensors} tensors " + f"({', '.join(shape_info) if shape_info else 'uniform'}), " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()}" + ) + + @staticmethod + def make_grouped_tensor_with_shapes( + num_tensors: int, + shape: List[Tuple[int, int]], + quantizers: Optional[List[Quantizer]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + shape: 2D shape of each tensor (len num_tensors) + quantizers: List of Quantizers for each tensor + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # First dim + first_dim_list = [s[0] for s in shape] + uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) + logical_first_dim = sum(first_dim_list) + if uniform_first_dim: + first_dims = None + else: + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + + # Last dim + last_dim_list = [s[1] for s in shape] + logical_last_dim = last_dim_list[0] + assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" + + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=first_dims, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=logical_last_dim, + quantizers=quantizers, + device=device, + dtype=dtype, + ) + + @staticmethod + def make_grouped_tensor( + num_tensors: int, + first_dims: Optional[torch.Tensor], + last_dims: Optional[torch.tensor], + logical_first_dim: int, + logical_last_dim: int, + quantizers: Optional[List[Quantizer]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + logical_first_dim: Logical first dimension + logical_last_dim: Logical last dimension + quantizers: List of Quantizers for each tensor + Used to figure out the recipe and what to allocate. + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # Set device + if device is None: + device = torch.cuda.current_device() + + # Shape patterns and validation. + all_same_first = first_dims is None + all_same_last = last_dims is None + + assert all_same_last, "Last dim must be uniform for GroupedTensor" + assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor" + assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor" + + # assert ( + # logical_first_dim % 128 == 0 + # ), "Logical first dim must be divisible by 128" + # assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128" + + # Calculate tensor offsets (cumulative element offsets) + tensor_offsets = None + offsets = None + shape = [] + if not all_same_first: + # Need explicit offsets for non-uniform shapes + # Offsets are based on number of elements and not pointers. + # Kernels need to calculate precise pointers based on size of elements. + + # TODO(ksivaman): Single kernel + remove the host offset calculation. + tensor_offsets = torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) + else: + offsets = [ + i * logical_first_dim * logical_last_dim // num_tensors + for i in range(num_tensors + 1) + ] + for i in range(num_tensors): + shape.append((logical_first_dim // num_tensors, logical_last_dim)) + + # Calculate logical shape based + logical_shape = (logical_first_dim, logical_last_dim) + + quantizer = quantizers[0] if isinstance(quantizers, list) else quantizers + print(f">>>>>>>>>>>>> quantizers: {quantizers}") + print(f">>>>>>>>>>>>> quantizer: {quantizer}") + no_quantization = quantizer is None + + rowwise_usage = quantizer.rowwise_usage if not no_quantization else True + columnwise_usage = quantizer.columnwise_usage if not no_quantization else False + + # Calculate total elements across all tensors + total_elements = logical_first_dim * logical_last_dim + + data = None + columnwise_data = None + scale_inv = None + columnwise_scale_inv = None + amax = None + columnwise_amax = None + scale = None + scale_inv_offsets = None + columnwise_scale_inv_offsets = None + if no_quantization: + assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=dtype, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=dtype, device=device) + elif quantizer._get_compatible_recipe().mxfp8(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse buffer for MXFP8 - complex shape based on block scaling + # For grouped tensors, we need to calculate scale_inv size for all tensors + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_elements = math.prod(scale_inv_shape) + total_scale_elements += scale_elements + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + columnwise_scale_elements = math.prod(scale_inv_shape) + total_columnwise_scale_elements += columnwise_scale_elements + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + elif quantizer._get_compatible_recipe().delayed(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Amax buffer for delayed scaling - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().nvfp4(): + + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) + data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device) + # Scale inverse buffer for NVFP4 - complex shape based on block scaling + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + # Amax buffer - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) + columnwise_data = torch.empty( + (total_elements) // 2, dtype=torch.uint8, device=device + ) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + # Columnwise amax buffer - one per tensor + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().float8_block_scaling(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - size depends on block configuration + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.float32, device=device + ) + elif quantizer._get_compatible_recipe().float8_current_scaling(): + # Current scaling - per-tensor scaling computed on the fly + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Scale and amax buffers for current scaling - one per tensor + scale = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + else: + raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}") + + grouped_tensor = GroupedTensor( + num_tensors=num_tensors, + shape=shape, + dtype=dtype, + quantizers=quantizers, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + logical_shape=logical_shape, + ) + + # grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() + return grouped_tensor + + def split_into_quantized_tensors( + self, + ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]: + """ + Split the GroupedTensor into a list of `num_tensors` + quantized tensors based on the quantizer. No additional memory allocation is performed, + so the tensors returned are the same as the ones used to create the GroupedTensor. + + If quantizer is None, returns normal torch tensors. + If quantizer.internal is True, returns QuantizedTensorStorage. + Otherwise, returns QuantizedTensor. + + TODO(ksivaman): Block cases where any dims are varying. This is needed only + to expose the weights as separate parameters. + """ + + result = [] + + no_quantization = self.quantizers is None + + # Case 1: No quantization - return regular torch tensors + if no_quantization: + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + + # Get tensor data slice + if self.offsets is not None: + start_offset = self.offsets[i] + numel = tensor_shape[0] * tensor_shape[1] + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + else: + # All same shape case + numel = tensor_shape[0] * tensor_shape[1] + start_offset = i * numel + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + + return result + + # Case 2: Quantized tensors + recipe = self.quantizers[0]._get_compatible_recipe() + + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + numel = tensor_shape[0] * tensor_shape[1] + + # Get data offsets + if self.offsets is not None: + data_start = self.offsets[i] + data_end = data_start + numel + else: + # All same shape + data_start = i * numel + data_end = data_start + numel + + # Special shape handling for NVFP4. + nvfp4 = self.quantizers[0]._get_compatible_recipe().nvfp4() + if nvfp4: + data_start = data_start // 2 + data_end = data_end // 2 + + # Extract rowwise and columnwise data + rowwise_data = None + columnwise_data = None + + if self.has_data(): + if nvfp4: + rowwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4(tensor_shape) + else: + rowwise_tensor_shape = tensor_shape + rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + + if self.has_columnwise_data(): + columnwise_tensor_shape = self.quantizers[0].get_columnwise_shape(tensor_shape) + if nvfp4: + columnwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4( + columnwise_tensor_shape + ) + columnwise_data = self.columnwise_data[data_start:data_end].view( + columnwise_tensor_shape + ) + + # MXFP8 format + if recipe.mxfp8(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Calculate expected scale shape for MXFP8 + scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + if self.quantizers[0].internal: + mxfp8_tensor_class = MXFP8TensorStorage + else: + mxfp8_tensor_class = MXFP8Tensor + tensor = mxfp8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + ) + result.append(tensor) + + # Delayed scaling or current scaling (both use Float8TensorStorage) + elif recipe.delayed() or recipe.float8_current_scaling(): + # Scale inverse - one per tensor + scale_inv = None + if self.scale_inv is not None: + scale_inv = self.scale_inv[i : i + 1] + + if self.quantizers[0].internal: + float8_tensor_class = Float8TensorStorage + else: + float8_tensor_class = Float8Tensor + + tensor = float8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + data=rowwise_data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + data_transpose=columnwise_data, + ) + result.append(tensor) + + # Float8 block scaling + elif recipe.float8_block_scaling(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Compute is_2D_scaled and data_format from quantizer attributes + is_2D_scaled = self.quantizers[0].block_scaling_dim == 2 + + if self.quantizers[0].internal: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage + else: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensor + + tensor = float8_blockwise_q_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + is_2D_scaled=is_2D_scaled, + ) + result.append(tensor) + + # NVFP4 format + elif recipe.nvfp4(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + amax_rowwise = None + amax_columnwise = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Extract amax - one per tensor + if self.amax is not None: + amax_rowwise = self.amax[i : i + 1] + + if self.columnwise_amax is not None: + amax_columnwise = self.columnwise_amax[i : i + 1] + + if self.quantizers[0].internal: + nvfp4_tensor_class = NVFP4TensorStorage + else: + nvfp4_tensor_class = NVFP4Tensor + + tensor = nvfp4_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=self.quantizers[0].dtype, + quantizer=self.quantizers[0], + with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + ) + result.append(tensor) + + else: + raise ValueError(f"Unsupported quantization recipe: {recipe}") + + return result + + @staticmethod + def create_and_quantize( + tensors: int, + quantizers: None | List[Quantizer], + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize given tensors into quantized tensors with underlying + storage allocated in a GroupedTensor. + """ + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=len(tensors), + shape=[t.shape for t in tensors], + quantizers=quantizers, + device=device, + dtype=dtype, + ) + + grouped_tensor.quantize(tensors, noop_flag=noop_flag) + + return grouped_tensor + + def quantize( + self, + tensors: List[torch.Tensor], + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize the GroupedTensor inplace. + """ + + quantized_tensors = self.split_into_quantized_tensors() + for i in range(self.num_tensors): + self.quantizers[0].update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors \ No newline at end of file From 46f2eb10fe8780b7d1de524758c58eda07a0e06b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:44:22 -0800 Subject: [PATCH 07/59] quantize mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention.py | 5 +- .../attention/dot_product_attention/utils.py | 132 ++++++++++++++++-- .../pytorch/tensor/storage/grouped_tensor.py | 36 ++--- 3 files changed, 141 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 55553d30be..eb905d7b93 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -583,8 +583,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.custom(): - return + print(f"fp8_recipe: {fp8_recipe}") + # if fp8_recipe.custom(): + # return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 78083c0b0b..76f28f449d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -40,7 +40,8 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2192,24 +2193,131 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): q,k,v = [x.contiguous() for x in [q, k, v]] # bshd_bshd_bhsd -> bhsd_bhsd_bhsd + # thd_thd_thd -> htd_htd_htd qkv_quantizer.optimize_for_gemm = True qkv_quantizer._internal = False dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") - dim_others = [i for i in range(len(v.shape)) if i != dim_s] - perm = [*dim_others, dim_s] - # perm = [*dim_others[:-1], dim_s, dim_others[-1]] - v = v.permute(*perm).contiguous() - qkv_layout = "bshd_bshd_bhds" - inv = [0] * len(perm) - for i, p in enumerate(perm): - inv[p] = i - # v = v.permute(*inv) - q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() + def permute_x(x): + dim_others = [i for i in range(len(x.shape)) if i != dim_s] + perm = [*dim_others[:-1], dim_s, dim_others[-1]] + x = x.permute(*perm).contiguous() + return x + q, k, v = [permute_x(x) for x in [q, k, v]] + # consider bhsd for now + batch_size, num_heads = q.shape[0], q.shape[1] + seq_len, head_dim = q.shape[-2], q.shape[-1] + num_tensors = 3 * batch_size * num_heads + # qkv = torch.cat([q, k, v], dim=0).reshape(num_tensors, seq_len, head_dim) + # qkv_list = [qkv[i] for i in range(num_tensors)] + # print(f">>>>>>>>>>>> num_tensors: {num_tensors}") + shapes = [(seq_len, head_dim) for _ in range(num_tensors)] + quantizers = [qkv_quantizer] * num_tensors + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shapes, + quantizers=None, + device="cuda", + dtype=src_nominal_dtype, + ) + offset = 0 + for x in [q, k, v]: + numel = x.numel() + grouped_input.data[offset : offset + numel].copy_(x.reshape(-1)) + offset += numel + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shapes, + quantizers=quantizers, + device="cuda", + ) + _ = tex.quantize_grouped(grouped_input, grouped_output) + print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") + print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") + # grouped_output_list = [grouped_output[i] for i in range(num_tensors)] + # q_fp8, k_fp8, v_fp8 = grouped_output_list[0:32], grouped_output_list[32:64], grouped_output_list[64:] + # grouped_output.num_tensors = 3 + # grouped_output.quantizers = [qkv_quantizer] * 3 + # grouped_output.shape = [(batch_size * num_heads * seq_len, head_dim) for _ in range(3)] + # grouped_output.dtype = src_nominal_dtype + # grouped_output.data = torch.cat([x.reshape(-1) for x in [q, k, v]], dim=0) + # q_fp8, k_fp8, v_fp8 = grouped_output.split_into_quantized_tensors() + + def split_qkv(grouped_tensor, num_tensors): + rowwise_shape = q.shape + rowwise_scale_inv_shape = (*q.shape[:-1], q.shape[-1]//32) + columnwise_shape = q.shape + columnwise_scale_inv_shape = (*q.shape[:-2], q.shape[-2]//32, q.shape[-1]) + rowwise_data = grouped_tensor.data.view(num_tensors, *rowwise_shape).split([1] * num_tensors) + rowwise_scale_inv = grouped_tensor.scale_inv.view(num_tensors, *rowwise_scale_inv_shape).split([1] * num_tensors) + columnwise_data = grouped_tensor.columnwise_data.view(num_tensors, *columnwise_shape).split([1] * num_tensors) + columnwise_scale_inv = grouped_tensor.columnwise_scale_inv.view(num_tensors, *columnwise_scale_inv_shape).split([1] * num_tensors) + print(f">>>>>>>>>>>> rowwise_data: {len(rowwise_data)}, rowwise_scale_inv: {len(rowwise_scale_inv)}, columnwise_data: {len(columnwise_data)}, columnwise_scale_inv: {len(columnwise_scale_inv)}") + return [MXFP8Tensor( + shape=q.shape, + dtype=q.dtype, + rowwise_data=rowwise_data[i].squeeze(0), + rowwise_scale_inv=rowwise_scale_inv[i].squeeze(0), + columnwise_data=columnwise_data[i].squeeze(0), + columnwise_scale_inv=columnwise_scale_inv[i].squeeze(0), + fp8_dtype=qkv_quantizer.dtype, + quantizer=qkv_quantizer, + with_gemm_swizzled_scales=qkv_quantizer.optimize_for_gemm, + ) for i in range(num_tensors)] + q_fp8, k_fp8, v_fp8 = split_qkv(grouped_output, 3) + + print(f">>>>>>>>>>>> q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") + print(f">>>>>>>>>>>> rowwise_data: q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f">>>>>>>>>>>> rowwise_scale_inv: q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f">>>>>>>>>>>> columnwise_data: q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f">>>>>>>>>>>> columnwise_scale_inv: q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + + # print(f">>>>>>>>>>>> grouped_output: {len(grouped_output) if grouped_output is not None else None}") + + # qkv_mxfp8 = grouped_tensor.quantize(qkv_list) + # print(f">>>>>>>>>>>> qkv_mxfp8: {type(qkv_mxfp8)}") + # qkv_mxfp8_list = [qkv_mxfp8[i] for i in range(num_tensors)] + # print(f">>>>>>>>>>>> qkv_mxfp8: {qkv_mxfp8}") + # print(f">>>>>>>>>>>> qkv_mxfp8: {len(qkv_mxfp8_list)}") + # print(f">>>>>>>>>>>> qkv_mxfp8.shape: {qkv_mxfp8.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_data.shape: {qkv_mxfp8._rowwise_data.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_scale_inv.shape: {qkv_mxfp8._rowwise_scale_inv.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_data.shape: {qkv_mxfp8._columnwise_data.shape}") + # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_scale_inv.shape: {qkv_mxfp8._columnwise_scale_inv.shape}") + # q_fp8, k_fp8, v_fp8 = qkv_mxfp8[0::batch_size * num_heads], qkv_mxfp8[batch_size:2*batch_size], qkv_mxfp8[2*batch_size:] + + # q_fp8, k_fp8, v_fp8 = qkv_mxfp8.split_into_quantized_tensors() + + print(f"q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + + # q_fp8_rowwise, k_fp8_rowwise, v_fp8_rowwise = [x._rowwise_data for x in qkv_mxfp8] + # q_fp8_columnwise, k_fp8_columnwise, v_fp8_columnwise = [x._columnwise_data for x in qkv_mxfp8] + # q_fp8, k_fp8, v_fp8 = q_fp8_rowwise, k_fp8_rowwise, v_fp8_columnwise + + # dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") + # dim_others = [i for i in range(len(v.shape)) if i != dim_s] + # perm = [*dim_others, dim_s] + # # perm = [*dim_others[:-1], dim_s, dim_others[-1]] + # v = v.permute(*perm).contiguous() + + qkv_layout = "bhsd_bhsd_bhsd" + + # inv = [0] * len(perm) + # for i, p in enumerate(perm): + # inv[p] = i + # # v = v.permute(*inv) + + # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + # # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() + # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") return q_fp8, k_fp8, v_fp8, qkv_layout match qkv_group: diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index ad85a448e6..522c25f370 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -90,23 +90,23 @@ def __init__( offsets: Vector of integer offsets for each tensor. logical_shape: 2D tuple representing conceptual shape """ - print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") - print(f">>>>>>>>>>>> shape: {shape}") - print(f">>>>>>>>>>>> dtype: {dtype}") - print(f">>>>>>>>>>>> data: {data.shape}") - print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") - print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") - print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") - print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") - print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") - print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") - print(f">>>>>>>>>>>> first_dims: {first_dims}") - print(f">>>>>>>>>>>> last_dims: {last_dims}") - print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets}") - print(f">>>>>>>>>>>> offsets: {offsets}") - print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets}") - print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets}") - print(f">>>>>>>>>>>> logical_shape: {logical_shape}") + # print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") + # print(f">>>>>>>>>>>> shape: {shape}") + # print(f">>>>>>>>>>>> dtype: {dtype}") + # print(f">>>>>>>>>>>> data: {data.shape}") + # print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") + # print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") + # print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") + # print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") + # print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") + # print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") + # print(f">>>>>>>>>>>> first_dims: {first_dims.shape}") + # print(f">>>>>>>>>>>> last_dims: {last_dims.shape}") + # print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets.shape}") + # print(f">>>>>>>>>>>> offsets: {offsets.shape}") + # print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets.shape}") + # print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets.shape}") + # print(f">>>>>>>>>>>> logical_shape: {logical_shape.shape}") print(f">>>>>>>>>>>> num_tensors: {num_tensors}") self.num_tensors = num_tensors @@ -434,7 +434,7 @@ def make_grouped_tensor( logical_shape = (logical_first_dim, logical_last_dim) quantizer = quantizers[0] if isinstance(quantizers, list) else quantizers - print(f">>>>>>>>>>>>> quantizers: {quantizers}") + # print(f">>>>>>>>>>>>> quantizers: {quantizers}") print(f">>>>>>>>>>>>> quantizer: {quantizer}") no_quantization = quantizer is None From e86207c7d7afe0c982b33633bb84c55c5fb01899 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 9 Feb 2026 19:53:25 -0800 Subject: [PATCH 08/59] fix shapes/strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 21 ++- .../cast/mxfp8/group_quantize_mxfp8.cuh | 5 +- .../common/cast/mxfp8/quantize_mxfp8.cuh | 5 +- .../common/fused_attn/fused_attn.cpp | 70 ++++---- .../common/fused_attn/fused_attn_fp8.cu | 13 +- transformer_engine/common/fused_attn/utils.cu | 36 ++-- .../include/transformer_engine/fused_attn.h | 10 +- .../common/util/pybind_helper.h | 4 +- .../dot_product_attention/backends.py | 14 ++ .../attention/dot_product_attention/utils.py | 130 ++++---------- .../pytorch/cpp_extensions/fused_attn.py | 4 +- .../pytorch/csrc/type_converters.cpp | 15 +- .../pytorch/tensor/mxfp8_tensor.py | 2 + .../pytorch/tensor/storage/grouped_tensor.py | 162 ++++++++++-------- 14 files changed, 229 insertions(+), 262 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 0301f77ae8..5602114143 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2157,16 +2157,19 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) + # print(f">>>>>> fused_attn_bwd_fp8: {fused_attn_bwd_fp8} {is_training}") + # torch.save(fused_attn_fwd_fp8, "fused_attn_fwd_fp8.pt") - # os.environ["NVTE_FLASH_ATTN"] = "0" - # os.environ["NVTE_FUSED_ATTN"] = "1" - # os.environ["NVTE_UNFUSED_ATTN"] = "0" - # if config.dropout_p == 0.0: - # # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - # fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - # dtype, config, False, qkv_layout, is_training, fp8_recipe - # ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + if config.dropout_p == 0.0: + # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training, fp8_recipe + ) + # torch.save(fused_attn_fwd_f16, "fused_attn_fwd_f16.pt") atol = 5e-1 rtol = 5e-2 diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 35e605067d..ea81e6c516 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -244,7 +244,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { -printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); +// printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -935,16 +935,19 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations OType, true, true, WITH_GEMM_SWIZZLED_SCALES>; switch (scaling_type) { case ScalingType::ROWWISE: { + printf(">>>>>>>>>>>> grouped: ScalingType::ROWWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::COLWISE: { + printf(">>>>>>>>>>>> grouped: ScalingType::COLWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::BIDIMENSIONAL: { + printf(">>>>>>>>>>>> grouped: ScalingType::BIDIMENSIONAL\n"); kernel = group_quantize_mxfp8_kernel; break; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a3e7db94d1..82bf497a3b 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -55,7 +55,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { -printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); +// printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -777,6 +777,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, switch (scaling_type) { case ScalingType::ROWWISE: { + printf(">>>>>>>>>>>> non-grouped: ScalingType::ROWWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -792,6 +793,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { + printf(">>>>>>>>>>>> non-grouped: ScalingType::COLWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -807,6 +809,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { + printf(">>>>>>>>>>>> non-grouped: ScalingType::BIDIMENSIONAL\n"); auto kernel = quantize_mxfp8_kernel; diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 61a8d61635..3ed540b8f2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -117,8 +117,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: - return NVTE_QKV_Layout_Group::NVTE_HD_HD_DS; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_SD_SD_SD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -159,8 +159,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: - return NVTE_QKV_Format::NVTE_BHDS; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -180,8 +180,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHDS: - return NVTE_QKV_Format::NVTE_BHDS; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -201,8 +201,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; - case NVTE_QKV_Format::NVTE_BHDS: - return NVTE_QKV_Format::NVTE_BHDS; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -248,29 +248,29 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && // 8.9: t3hd, max_s=512, d=64, padding - // ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - // qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - // max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - // (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - // max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - // (cudnn_runtime_version >= 90700 && - // // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // // sm90: fwd d<=256, bwd d=128 only - // // sm100: fwd d<=128, bwd d<=128 - // ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - // (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - // (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - // head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - // (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - // attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHDS) && + ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -1151,8 +1151,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; @@ -1165,6 +1163,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -1277,8 +1277,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; @@ -1291,6 +1289,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index bf4f019a67..f7698da5c3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1781,12 +1781,13 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); // need to double check + NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); @@ -2527,10 +2528,10 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; devPtrDescaleK = input_K->scale_inv.dptr; - devPtrV = input_V->data.dptr; - devPtrDescaleV = input_V->scale_inv.dptr; - // devPtrV = input_V->columnwise_data.dptr; - // devPtrDescaleV = input_V->columnwise_scale_inv.dptr; + // devPtrV = input_V->data.dptr; + // devPtrDescaleV = input_V->scale_inv.dptr; + devPtrV = input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; } else { @@ -2589,7 +2590,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHDS)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 94a495153e..3ea40126cc 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,32 +293,24 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS: + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { - strideA[batch_dim_idx] = s_q * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_dim_idx] = d; strideA[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_dim_idx] = d; strideA[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d * s_kv; - strideA[seqlen_dim_idx] = 1; - strideA[hidden_dim_idx] = s_kv; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_transpose_dim_idx] = 1; - strideA[hidden_transpose_dim_idx] = h * d; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { - strideA[batch_dim_idx] = s_kv * h * d; - strideA[head_dim_idx] = d * s_kv; - strideA[seqlen_transpose_dim_idx] = s_kv; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_transpose_dim_idx] = d; strideA[hidden_transpose_dim_idx] = 1; } break; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index bc97d2a853..204d8f3d5a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,7 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ - NVTE_BSHD_BSHD_BHDS = 25, /*!< BSHD_BSHD_BHDS layout */ + NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -71,8 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, - /*! BSHD_BSHD_BHDS QKV layouts, e.g. BSHD_BSHD_BHDS */ - NVTE_HD_HD_DS = 6, + /*! SD_SD_SD QKV layouts, e.g. BHSD_BHSD_BHSD */ + NVTE_SD_SD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -93,8 +93,8 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, - /*! BSHD_BSHD_BHSD QKV format, e.g. BSHD_BSHD_BHSD */ - NVTE_BHDS = 7, + /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ + NVTE_BHSD = 7, }; /*! \enum NVTE_Bias_Type diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b81c488005..96e6803ec5 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -49,7 +49,7 @@ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ - .value("NVTE_BHDS", NVTE_QKV_Format::NVTE_BHDS); \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -76,7 +76,7 @@ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ - .value("NVTE_BSHD_BSHD_BHDS", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BHDS); \ + .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 01c26a9728..5cc23eabd8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1217,6 +1217,9 @@ def forward( # FP8 attention: torch.float16 or torch.bfloat16 out_nominal_dtype = q.dtype + # save original qkv_layout + original_qkv_layout = qkv_layout + max_logit = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -1276,6 +1279,17 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) + if original_qkv_layout != qkv_layout: + print(f">>>>>>>>>>>> original_qkv_layout: {original_qkv_layout}") + print(f">>>>>>>>>>>> qkv_layout: {qkv_layout}") + print(f">>>>>>>>>>>> out_.shape: {out_.shape}") + original_qkv_format = original_qkv_layout.split("_")[0] + new_qkv_format = qkv_layout.split("_")[0] + perm = [] + for i in new_qkv_format: + perm.append(original_qkv_format.find(i)) + out_ = out_.permute(*perm).contiguous() + print(f">>>>>>>>>>>> out_.shape permuted: {out_.shape}") # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 76f28f449d..c5ba652c28 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2191,11 +2191,13 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): if isinstance(qkv_quantizer, MXFP8Quantizer): # bs3hd -> bshd_bshd_bhsd q,k,v = [x.contiguous() for x in [q, k, v]] + print(f">>>>>>>>>>>> Contiguous shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # bshd_bshd_bhsd -> bhsd_bhsd_bhsd # thd_thd_thd -> htd_htd_htd qkv_quantizer.optimize_for_gemm = True - qkv_quantizer._internal = False + qkv_quantizer.internal = False + print(f">>>>>>>>>>>> qkv_quantizer.internal: {qkv_quantizer.internal}") dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") def permute_x(x): dim_others = [i for i in range(len(x.shape)) if i != dim_s] @@ -2203,102 +2205,34 @@ def permute_x(x): x = x.permute(*perm).contiguous() return x q, k, v = [permute_x(x) for x in [q, k, v]] - # consider bhsd for now - batch_size, num_heads = q.shape[0], q.shape[1] - seq_len, head_dim = q.shape[-2], q.shape[-1] - num_tensors = 3 * batch_size * num_heads - # qkv = torch.cat([q, k, v], dim=0).reshape(num_tensors, seq_len, head_dim) - # qkv_list = [qkv[i] for i in range(num_tensors)] - # print(f">>>>>>>>>>>> num_tensors: {num_tensors}") - shapes = [(seq_len, head_dim) for _ in range(num_tensors)] - quantizers = [qkv_quantizer] * num_tensors - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shape=shapes, - quantizers=None, - device="cuda", - dtype=src_nominal_dtype, - ) - offset = 0 - for x in [q, k, v]: - numel = x.numel() - grouped_input.data[offset : offset + numel].copy_(x.reshape(-1)) - offset += numel - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shape=shapes, - quantizers=quantizers, - device="cuda", - ) - _ = tex.quantize_grouped(grouped_input, grouped_output) - print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") - # grouped_output_list = [grouped_output[i] for i in range(num_tensors)] - # q_fp8, k_fp8, v_fp8 = grouped_output_list[0:32], grouped_output_list[32:64], grouped_output_list[64:] - # grouped_output.num_tensors = 3 - # grouped_output.quantizers = [qkv_quantizer] * 3 - # grouped_output.shape = [(batch_size * num_heads * seq_len, head_dim) for _ in range(3)] - # grouped_output.dtype = src_nominal_dtype - # grouped_output.data = torch.cat([x.reshape(-1) for x in [q, k, v]], dim=0) - # q_fp8, k_fp8, v_fp8 = grouped_output.split_into_quantized_tensors() - - def split_qkv(grouped_tensor, num_tensors): - rowwise_shape = q.shape - rowwise_scale_inv_shape = (*q.shape[:-1], q.shape[-1]//32) - columnwise_shape = q.shape - columnwise_scale_inv_shape = (*q.shape[:-2], q.shape[-2]//32, q.shape[-1]) - rowwise_data = grouped_tensor.data.view(num_tensors, *rowwise_shape).split([1] * num_tensors) - rowwise_scale_inv = grouped_tensor.scale_inv.view(num_tensors, *rowwise_scale_inv_shape).split([1] * num_tensors) - columnwise_data = grouped_tensor.columnwise_data.view(num_tensors, *columnwise_shape).split([1] * num_tensors) - columnwise_scale_inv = grouped_tensor.columnwise_scale_inv.view(num_tensors, *columnwise_scale_inv_shape).split([1] * num_tensors) - print(f">>>>>>>>>>>> rowwise_data: {len(rowwise_data)}, rowwise_scale_inv: {len(rowwise_scale_inv)}, columnwise_data: {len(columnwise_data)}, columnwise_scale_inv: {len(columnwise_scale_inv)}") - return [MXFP8Tensor( - shape=q.shape, - dtype=q.dtype, - rowwise_data=rowwise_data[i].squeeze(0), - rowwise_scale_inv=rowwise_scale_inv[i].squeeze(0), - columnwise_data=columnwise_data[i].squeeze(0), - columnwise_scale_inv=columnwise_scale_inv[i].squeeze(0), - fp8_dtype=qkv_quantizer.dtype, - quantizer=qkv_quantizer, - with_gemm_swizzled_scales=qkv_quantizer.optimize_for_gemm, - ) for i in range(num_tensors)] - q_fp8, k_fp8, v_fp8 = split_qkv(grouped_output, 3) + print(f">>>>>>>>>>>> Permuted shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") + + original_shapes = [q.shape, k.shape, v.shape] + b, h_q, s_q, d_qk = q.shape + _, h_kv, s_kv, d_kv = v.shape + assert k.shape == (b, h_kv, s_kv, d_qk) + assert s_q % 128 == 0 + assert s_kv % 128 == 0 + assert d_qk % 32 == 0 + assert d_kv % 32 == 0 + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") + # consider bhsd for now + grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) + print(f">>>>>>>>>>>> grouped_tensor: {type(grouped_tensor)}") + print(f">>>>>>>>>>>> grouped_tensor.quantized_tensors: {type(grouped_tensor.quantized_tensors)}") + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") + q_fp8, k_fp8, v_fp8 = [x.view(*original_shapes[i]) for i, x in enumerate([q_fp8, k_fp8, v_fp8])] print(f">>>>>>>>>>>> q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") - print(f">>>>>>>>>>>> rowwise_data: q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f">>>>>>>>>>>> rowwise_scale_inv: q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f">>>>>>>>>>>> columnwise_data: q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f">>>>>>>>>>>> columnwise_scale_inv: q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - - # print(f">>>>>>>>>>>> grouped_output: {len(grouped_output) if grouped_output is not None else None}") - - # qkv_mxfp8 = grouped_tensor.quantize(qkv_list) - # print(f">>>>>>>>>>>> qkv_mxfp8: {type(qkv_mxfp8)}") - # qkv_mxfp8_list = [qkv_mxfp8[i] for i in range(num_tensors)] - # print(f">>>>>>>>>>>> qkv_mxfp8: {qkv_mxfp8}") - # print(f">>>>>>>>>>>> qkv_mxfp8: {len(qkv_mxfp8_list)}") - # print(f">>>>>>>>>>>> qkv_mxfp8.shape: {qkv_mxfp8.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_data.shape: {qkv_mxfp8._rowwise_data.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._rowwise_scale_inv.shape: {qkv_mxfp8._rowwise_scale_inv.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_data.shape: {qkv_mxfp8._columnwise_data.shape}") - # print(f">>>>>>>>>>>> qkv_mxfp8._columnwise_scale_inv.shape: {qkv_mxfp8._columnwise_scale_inv.shape}") - # q_fp8, k_fp8, v_fp8 = qkv_mxfp8[0::batch_size * num_heads], qkv_mxfp8[batch_size:2*batch_size], qkv_mxfp8[2*batch_size:] - - # q_fp8, k_fp8, v_fp8 = qkv_mxfp8.split_into_quantized_tensors() - - print(f"q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") - print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - - # q_fp8_rowwise, k_fp8_rowwise, v_fp8_rowwise = [x._rowwise_data for x in qkv_mxfp8] - # q_fp8_columnwise, k_fp8_columnwise, v_fp8_columnwise = [x._columnwise_data for x in qkv_mxfp8] - # q_fp8, k_fp8, v_fp8 = q_fp8_rowwise, k_fp8_rowwise, v_fp8_columnwise - + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") + print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") # dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") # dim_others = [i for i in range(len(v.shape)) if i != dim_s] # perm = [*dim_others, dim_s] @@ -2312,12 +2246,6 @@ def split_qkv(grouped_tensor, num_tensors): # inv[p] = i # # v = v.permute(*inv) - # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] - # # v_fp8._rowwise_data = v_fp8._rowwise_data.permute(*inv).contiguous() - # print(f"q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - # print(f"q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - # print(f"q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - # print(f"q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") return q_fp8, k_fp8, v_fp8, qkv_layout match qkv_group: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 2748228b42..b4811eb4f6 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,7 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, - "bshd_bshd_bhds": NVTE_QKV_Format.NVTE_BHDS, + "bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -71,7 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, - "bshd_bshd_bhds": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BHDS, + "bhsd_bhsd_bhsd": NVTE_QKV_Layout.NVTE_BHSD_BHSD_BHSD, } AttnBiasType = { diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 8ab8dc1d48..c17be6c855 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -207,15 +207,12 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { DType quantizer_dtype = DType::kNumTypes; NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; bool with_gemm_swizzled_scales = false; - if (!tensor.attr("quantizers").is_none()) { - const auto quantizers = tensor.attr("quantizers").cast(); - quantizer = quantizers[0]; - if (!quantizers.empty() && !quantizer.is_none()) { - scaling_mode = ScalingModeFromQuantizer(quantizer); - quantizer_dtype = quantizer.attr("dtype").cast(); - with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); - printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); - } + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer").cast(); + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); + printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); } auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index a283b43908..e4c658ed58 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -75,6 +75,8 @@ def update_quantized( src = src.contiguous() # Launch cast kernel + print(f">>>>>>>>>>>> src: {src.shape}") + print(f">>>>>>>>>>>> dst: {dst.shape}") tex.quantize(src, self, dst, noop_flag) # Update FP8 dtype diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 522c25f370..9771d61df8 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -8,7 +8,8 @@ import math import torch - +import transformer_engine +import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -52,8 +53,8 @@ class GroupedTensor: def __init__( self, num_tensors: int, - shape: List[Tuple[int, int]], - quantizers: Optional[List[Quantizer]] = None, + shapes: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, @@ -75,8 +76,8 @@ def __init__( Args: num_tensors: Number of tensors in the group - shape: 2D shape of each tensor (len num_tensors) - quantizers: List of Quantizers for the grouped tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for the grouped tensor data: Row-wise data buffer (1D flattened) columnwise_data: Column-wise data buffer (1D flattened) scale_inv: Row-wise scale inverse buffer @@ -110,8 +111,8 @@ def __init__( print(f">>>>>>>>>>>> num_tensors: {num_tensors}") self.num_tensors = num_tensors - self.quantizers = quantizers - self.shape = shape + self.quantizer = quantizer + self.shapes = shapes self.dtype = ( dtype if dtype is not None else torch.float32 ) # Default to float32 if not provided @@ -276,7 +277,7 @@ def clear(self) -> None: self.tensor_offsets = None self.logical_shape = (0, 0) self.num_tensors = 0 - self.quantizers = None + self.quantizer = None self.quantized_tensors = None self.offsets = None self.scale_inv_offsets = None @@ -286,7 +287,7 @@ def __repr__(self) -> str: """String representation of the GroupedTensor.""" return ( f"GroupedTensor(num_tensors={self.num_tensors}, " - f"shape={self.shape}, " + f"shapes={self.shapes}, " f"logical_shape={self.logical_shape}, " f"dtype={self.get_dtype()})" ) @@ -312,8 +313,8 @@ def __str__(self) -> str: @staticmethod def make_grouped_tensor_with_shapes( num_tensors: int, - shape: List[Tuple[int, int]], - quantizers: Optional[List[Quantizer]] = None, + shapes: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> GroupedTensor: @@ -322,8 +323,8 @@ def make_grouped_tensor_with_shapes( Args: num_tensors: Number of tensors - shape: 2D shape of each tensor (len num_tensors) - quantizers: List of Quantizers for each tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for the grouped tensor device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -332,16 +333,16 @@ def make_grouped_tensor_with_shapes( """ # First dim - first_dim_list = [s[0] for s in shape] + first_dim_list = [s[0] for s in shapes] uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) logical_first_dim = sum(first_dim_list) if uniform_first_dim: first_dims = None else: - first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) # Last dim - last_dim_list = [s[1] for s in shape] + last_dim_list = [s[1] for s in shapes] logical_last_dim = last_dim_list[0] assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" @@ -351,7 +352,7 @@ def make_grouped_tensor_with_shapes( last_dims=None, logical_first_dim=logical_first_dim, logical_last_dim=logical_last_dim, - quantizers=quantizers, + quantizer=quantizer, device=device, dtype=dtype, ) @@ -363,7 +364,7 @@ def make_grouped_tensor( last_dims: Optional[torch.tensor], logical_first_dim: int, logical_last_dim: int, - quantizers: Optional[List[Quantizer]] = None, + quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> GroupedTensor: @@ -376,7 +377,7 @@ def make_grouped_tensor( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) logical_first_dim: Logical first dimension logical_last_dim: Logical last dimension - quantizers: List of Quantizers for each tensor + quantizer: Quantizer for the grouped tensor Used to figure out the recipe and what to allocate. device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -405,7 +406,7 @@ def make_grouped_tensor( # Calculate tensor offsets (cumulative element offsets) tensor_offsets = None offsets = None - shape = [] + shapes = [] if not all_same_first: # Need explicit offsets for non-uniform shapes # Offsets are based on number of elements and not pointers. @@ -421,21 +422,18 @@ def make_grouped_tensor( offsets = tensor_offsets.tolist() first_dims_list = first_dims.tolist() for i in range(num_tensors): - shape.append((first_dims_list[i], logical_last_dim)) + shapes.append((first_dims_list[i], logical_last_dim)) else: offsets = [ i * logical_first_dim * logical_last_dim // num_tensors for i in range(num_tensors + 1) ] for i in range(num_tensors): - shape.append((logical_first_dim // num_tensors, logical_last_dim)) + shapes.append((logical_first_dim // num_tensors, logical_last_dim)) # Calculate logical shape based logical_shape = (logical_first_dim, logical_last_dim) - quantizer = quantizers[0] if isinstance(quantizers, list) else quantizers - # print(f">>>>>>>>>>>>> quantizers: {quantizers}") - print(f">>>>>>>>>>>>> quantizer: {quantizer}") no_quantization = quantizer is None rowwise_usage = quantizer.rowwise_usage if not no_quantization else True @@ -470,7 +468,7 @@ def make_grouped_tensor( # For grouped tensors, we need to calculate scale_inv size for all tensors total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) scale_elements = math.prod(scale_inv_shape) total_scale_elements += scale_elements @@ -484,7 +482,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements @@ -538,7 +536,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) if i < num_tensors - 1: @@ -569,7 +567,7 @@ def make_grouped_tensor( # Columnwise scale inverse total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) if i < num_tensors - 1: @@ -603,9 +601,9 @@ def make_grouped_tensor( grouped_tensor = GroupedTensor( num_tensors=num_tensors, - shape=shape, + shapes=shapes, dtype=dtype, - quantizers=quantizers, + quantizer=quantizer, data=data, columnwise_data=columnwise_data, scale_inv=scale_inv, @@ -643,13 +641,13 @@ def split_into_quantized_tensors( result = [] - no_quantization = self.quantizers is None + no_quantization = self.quantizer is None # Case 1: No quantization - return regular torch tensors if no_quantization: for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.shapes[i] # Get tensor data slice if self.offsets is not None: @@ -687,11 +685,11 @@ def split_into_quantized_tensors( return result # Case 2: Quantized tensors - recipe = self.quantizers[0]._get_compatible_recipe() + recipe = self.quantizer._get_compatible_recipe() for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.shapes[i] numel = tensor_shape[0] * tensor_shape[1] # Get data offsets @@ -704,7 +702,7 @@ def split_into_quantized_tensors( data_end = data_start + numel # Special shape handling for NVFP4. - nvfp4 = self.quantizers[0]._get_compatible_recipe().nvfp4() + nvfp4 = self.quantizer._get_compatible_recipe().nvfp4() if nvfp4: data_start = data_start // 2 data_end = data_end // 2 @@ -715,15 +713,15 @@ def split_into_quantized_tensors( if self.has_data(): if nvfp4: - rowwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4(tensor_shape) + rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape) else: rowwise_tensor_shape = tensor_shape rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) if self.has_columnwise_data(): - columnwise_tensor_shape = self.quantizers[0].get_columnwise_shape(tensor_shape) + columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape) if nvfp4: - columnwise_tensor_shape = self.quantizers[0].convert_shape_for_fp4( + columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4( columnwise_tensor_shape ) columnwise_data = self.columnwise_data[data_start:data_end].view( @@ -744,7 +742,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv.numel() # Calculate expected scale shape for MXFP8 - scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -757,12 +755,12 @@ def split_into_quantized_tensors( else: cscale_end = self.columnwise_scale_inv.numel() - cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) - if self.quantizers[0].internal: + if self.quantizer.internal: mxfp8_tensor_class = MXFP8TensorStorage else: mxfp8_tensor_class = MXFP8Tensor @@ -773,9 +771,9 @@ def split_into_quantized_tensors( rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], - with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, ) result.append(tensor) @@ -786,7 +784,7 @@ def split_into_quantized_tensors( if self.scale_inv is not None: scale_inv = self.scale_inv[i : i + 1] - if self.quantizers[0].internal: + if self.quantizer.internal: float8_tensor_class = Float8TensorStorage else: float8_tensor_class = Float8Tensor @@ -796,8 +794,8 @@ def split_into_quantized_tensors( dtype=self.dtype, data=rowwise_data, fp8_scale_inv=scale_inv, - fp8_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, data_transpose=columnwise_data, ) result.append(tensor) @@ -816,7 +814,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv.numel() # Get scale shape from quantizer - scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -830,15 +828,15 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv.numel() # Get columnwise scale shape from quantizer - cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) # Compute is_2D_scaled and data_format from quantizer attributes - is_2D_scaled = self.quantizers[0].block_scaling_dim == 2 + is_2D_scaled = self.quantizer.block_scaling_dim == 2 - if self.quantizers[0].internal: + if self.quantizer.internal: float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage else: float8_blockwise_q_tensor_class = Float8BlockwiseQTensor @@ -850,8 +848,8 @@ def split_into_quantized_tensors( rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, is_2D_scaled=is_2D_scaled, ) result.append(tensor) @@ -872,7 +870,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv.numel() # Get scale shape from quantizer - scale_shape = self.quantizers[0].get_scale_shape(tensor_shape, False) + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -886,7 +884,7 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv.numel() # Get columnwise scale shape from quantizer - cscale_shape = self.quantizers[0].get_scale_shape(tensor_shape, True) + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) @@ -898,7 +896,7 @@ def split_into_quantized_tensors( if self.columnwise_amax is not None: amax_columnwise = self.columnwise_amax[i : i + 1] - if self.quantizers[0].internal: + if self.quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage else: nvfp4_tensor_class = NVFP4Tensor @@ -912,9 +910,9 @@ def split_into_quantized_tensors( columnwise_scale_inv=columnwise_scale_inv, amax_rowwise=amax_rowwise, amax_columnwise=amax_columnwise, - fp4_dtype=self.quantizers[0].dtype, - quantizer=self.quantizers[0], - with_gemm_swizzled_scales=self.quantizers[0].optimize_for_gemm, + fp4_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, ) result.append(tensor) @@ -926,7 +924,7 @@ def split_into_quantized_tensors( @staticmethod def create_and_quantize( tensors: int, - quantizers: None | List[Quantizer], + quantizer: None | Quantizer, *, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -937,17 +935,41 @@ def create_and_quantize( storage allocated in a GroupedTensor. """ - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=len(tensors), + shapes=[t.shape for t in tensors], + quantizer=None, + device=device, + dtype=tensors[0].dtype, + ) + + offset = 0 + for tensor in tensors: + numel = tensor.numel() + grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=len(tensors), - shape=[t.shape for t in tensors], - quantizers=quantizers, + shapes=[t.shape for t in tensors], + quantizer=quantizer, device=device, dtype=dtype, ) - grouped_tensor.quantize(tensors, noop_flag=noop_flag) + _ = tex.quantize_grouped(grouped_input, grouped_output) + grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() - return grouped_tensor + # grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + # num_tensors=len(tensors), + # shapes=[t.shape for t in tensors], + # quantizer=None, + # device=device, + # dtype=tensors[0].dtype, + # ) + # grouped_tensor.quantize(tensors, noop_flag=noop_flag) + + return grouped_output def quantize( self, @@ -958,7 +980,9 @@ def quantize( Quantize the GroupedTensor inplace. """ - quantized_tensors = self.split_into_quantized_tensors() + self.quantized_tensors = self.split_into_quantized_tensors() + print(f">>>>>>>>>>>> tensors[0]: {type(tensors[0])}") + print(f">>>>>>>>>>>> quantized_tensors[0]: {type(self.quantized_tensors[0])}") for i in range(self.num_tensors): - self.quantizers[0].update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) - return quantized_tensors \ No newline at end of file + self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) + return self.quantized_tensors \ No newline at end of file From 4e854d523d056e5348d4ec5d122c0936aa00eb8d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 11 Feb 2026 18:50:31 -0800 Subject: [PATCH 09/59] fix unfused; clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 23 ++-- tests/pytorch/test_grouped_tensor.py | 110 ++++++++--------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 9 +- .../common/cast/mxfp8/quantize_mxfp8.cuh | 7 +- .../common/fused_attn/fused_attn.cpp | 23 ++-- .../common/fused_attn/fused_attn_fp8.cu | 18 +-- transformer_engine/common/recipe/__init__.py | 2 +- .../dot_product_attention/backends.py | 112 ++++++++++-------- .../dot_product_attention/context_parallel.py | 2 +- .../dot_product_attention.py | 5 +- .../attention/dot_product_attention/utils.py | 74 ++++++------ .../pytorch/cpp_extensions/fused_attn.py | 1 - .../pytorch/tensor/mxfp8_tensor.py | 5 +- .../pytorch/tensor/storage/grouped_tensor.py | 37 +----- 14 files changed, 183 insertions(+), 245 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5602114143..74deeceed2 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2139,15 +2139,15 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal dtype, config, True, qkv_layout, is_training, fp8_recipe ) - # if unfused_attn_supported: - # os.environ["NVTE_FLASH_ATTN"] = "0" - # os.environ["NVTE_FUSED_ATTN"] = "0" - # os.environ["NVTE_UNFUSED_ATTN"] = "1" - # _attention_backends["backend_selection_requires_update"] = True - # logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") - # unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - # dtype, config, True, qkv_layout, is_training, fp8_recipe - # ) + if unfused_attn_supported: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe + ) os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" @@ -2157,8 +2157,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) - # print(f">>>>>> fused_attn_bwd_fp8: {fused_attn_bwd_fp8} {is_training}") - # torch.save(fused_attn_fwd_fp8, "fused_attn_fwd_fp8.pt") os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" @@ -2169,7 +2167,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( dtype, config, False, qkv_layout, is_training, fp8_recipe ) - # torch.save(fused_attn_fwd_f16, "fused_attn_fwd_f16.pt") atol = 5e-1 rtol = 5e-2 @@ -2188,7 +2185,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if False: #unfused_attn_supported: + if unfused_attn_supported: logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 964c2d8e97..f0b2c35c0a 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -121,11 +121,11 @@ def setup_class(cls) -> None: def test_basic_construction_all_same_shape(self) -> None: """Test GroupedTensor construction with all tensors having same shape""" num_tensors = 4 - shape = [(256, 512) for _ in range(num_tensors)] + shapes = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -143,11 +143,11 @@ def test_basic_construction_all_same_shape(self) -> None: def test_basic_construction_varying_first_dim(self) -> None: """Test GroupedTensor construction with varying first dimension""" num_tensors = 3 - shape = [(128, 512), (256, 512), (384, 512)] + shapes = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -157,20 +157,20 @@ def test_basic_construction_varying_first_dim(self) -> None: assert not grouped_tensor.all_same_shape() assert not grouped_tensor.all_same_first_dim() assert grouped_tensor.all_same_last_dim() - assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.get_common_last_dim() == shapes[0][1] assert grouped_tensor.logical_shape == ( - sum(v for v, _ in shape), - shape[0][1], + sum(v for v, _ in shapes), + shapes[0][1], ) # sum of first dims def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 - shape = [(256, 512) for _ in range(num_tensors)] + shapes = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -186,7 +186,7 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: # Verify each tensor has correct shape and shares storage for i, tensor in enumerate(tensors): - assert tensor.shape == shape[i] + assert tensor.shape == shapes[i] assert isinstance(tensor, torch.Tensor) assert not hasattr(tensor, "_data") # Not a quantized tensor @@ -195,20 +195,20 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: assert tensor.data_ptr() >= original_data_ptr # Calculate expected offset - expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + expected_offset = i * (shapes[i][0] * shapes[i][1]) * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset @pytest.mark.parametrize("quantization", _quantization_params) def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shapes) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) @@ -225,18 +225,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None rowwise_data = _get_rowwise_data_tensor(tensor, quantization) assert rowwise_data is not None assert rowwise_data.data_ptr() >= original_data_ptr - numel = shape[i][0] * shape[i][1] + numel = shapes[i][0] * shapes[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset def test_split_varying_shapes(self) -> None: """Test split_into_quantized_tensors with varying shapes""" num_tensors = 3 - shape = [(128, 512), (256, 512), (384, 512)] + shapes = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, @@ -250,22 +250,22 @@ def test_split_varying_shapes(self) -> None: # Verify shapes and storage cumulative_offset = 0 for i, tensor in enumerate(tensors): - assert tensor.shape == shape[i] + assert tensor.shape == shapes[i] expected_offset = cumulative_offset * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset - cumulative_offset += shape[i][0] * shape[i][1] + cumulative_offset += shapes[i][0] * shapes[i][1] @pytest.mark.parametrize("quantization", _quantization_params) def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shapes) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) @@ -277,7 +277,7 @@ def test_quantize_inplace(self, quantization: str) -> None: ) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -291,7 +291,7 @@ def test_quantize_inplace(self, quantization: str) -> None: # Verify returned tensors point to the same storage for i, qtensor in enumerate(quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shape[i][0] * shape[i][1] + numel = shapes[i][0] * shapes[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -299,13 +299,13 @@ def test_quantize_inplace(self, quantization: str) -> None: def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 - shape = [(256, 512), (512, 512), (768, 512)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(256, 512), (512, 512), (768, 512)] + quantizer = make_quantizer(quantization, num_tensors, shapes) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) @@ -313,7 +313,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() # Create input tensors with varying shapes - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -323,7 +323,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: # Verify each tensor points to correct location cumulative_numel = 0 - for qtensor, tensor_shape in zip(quantized_tensors, shape): + for qtensor, tensor_shape in zip(quantized_tensors, shapes): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -333,16 +333,16 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: def test_static_quantize_method(self, quantization: str) -> None: """Test the static quantize method""" num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + shapes = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shapes) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] # Use static quantize method grouped_tensor = GroupedTensor.create_and_quantize( tensors=input_tensors, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -357,7 +357,7 @@ def test_static_quantize_method(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() for i, qtensor in enumerate(grouped_tensor.quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shape[i][0] * shape[i][1] + numel = shapes[i][0] * shapes[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -370,18 +370,16 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" # Test wont pass until the grouped quantization PR from Oleg is merged. num_tensors = 2 - shape = [(512, 1024) for _ in range(num_tensors)] + shapes = [(512, 1024) for _ in range(num_tensors)] # Create BF16 input tensors and pack into a grouped tensor - input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] - quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] - for q in quantizers: - q.optimize_for_gemm=True - quantized_tensors = [q(tensor) for q, tensor in zip(quantizers, input_tensors)] + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes] + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.optimize_for_gemm=True grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizers=None, + shapes=shapes, + quantizer=None, device="cuda", dtype=torch.bfloat16, ) @@ -392,30 +390,18 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) offset += numel - # Create MXFP8 output grouped tensor (rowwise only for easier validation) - # quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizers=quantizers, + shapes=shapes, + quantizer=quantizer, device="cuda", ) - print(f">>>>>>>>>>>> tex.quantize_grouped") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.data.shape if grouped_input.data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.data.shape if grouped_output.data is not None else None}") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.scale_inv.shape if grouped_input.scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.scale_inv.shape if grouped_output.scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_data.shape if grouped_input.columnwise_data is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_data.shape if grouped_output.columnwise_data is not None else None}") - print(f">>>>>>>>>>>> grouped_input: {grouped_input.columnwise_scale_inv.shape if grouped_input.columnwise_scale_inv is not None else None}") - print(f">>>>>>>>>>>> grouped_output: {grouped_output.columnwise_scale_inv.shape if grouped_output.columnwise_scale_inv is not None else None}") # Quantize using grouped API (handle both 2-arg and 3-arg bindings) _ = tex.quantize_grouped(grouped_input, grouped_output) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] - for tensor, quantizer in zip(input_tensors, quantizers): + for tensor in input_tensors: qtensor = quantizer(tensor) expected_data.append(qtensor._rowwise_data.reshape(-1)) expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) @@ -429,11 +415,11 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: def test_clear(self) -> None: """Test clear method""" num_tensors = 3 - shape = [(256, 512) for _ in range(num_tensors)] + shapes = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shapes, quantizer=None, device="cuda", dtype=torch.float32, diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index ea81e6c516..6a6715bdcc 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -244,7 +244,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { -// printf(">>>>>>>>>>>> WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -853,7 +852,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; - printf(">>>>>>>>>>>> with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); + printf(">>>>>>>>>>>> group_quantize_mxfp8 with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = @@ -935,19 +934,19 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations OType, true, true, WITH_GEMM_SWIZZLED_SCALES>; switch (scaling_type) { case ScalingType::ROWWISE: { - printf(">>>>>>>>>>>> grouped: ScalingType::ROWWISE\n"); + printf(">>>>>>>>>>>> group_quantize_mxfp8 ScalingType::ROWWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::COLWISE: { - printf(">>>>>>>>>>>> grouped: ScalingType::COLWISE\n"); + printf(">>>>>>>>>>>> group_quantize_mxfp8 ScalingType::COLWISE\n"); kernel = group_quantize_mxfp8_kernel; break; } case ScalingType::BIDIMENSIONAL: { - printf(">>>>>>>>>>>> grouped: ScalingType::BIDIMENSIONAL\n"); + printf(">>>>>>>>>>>> group_quantize_mxfp8 ScalingType::BIDIMENSIONAL\n"); kernel = group_quantize_mxfp8_kernel; break; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 82bf497a3b..a8135391e3 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -55,7 +55,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { -// printf(">>>>>>>>>>>> non-grouped: WITH_GEMM_SWIZZLED_SCALES: %d\n", WITH_GEMM_SWIZZLED_SCALES); #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -777,7 +776,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, switch (scaling_type) { case ScalingType::ROWWISE: { - printf(">>>>>>>>>>>> non-grouped: ScalingType::ROWWISE\n"); + printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::ROWWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -793,7 +792,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { - printf(">>>>>>>>>>>> non-grouped: ScalingType::COLWISE\n"); + printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::COLWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -809,7 +808,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { - printf(">>>>>>>>>>>> non-grouped: ScalingType::BIDIMENSIONAL\n"); + printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::BIDIMENSIONAL\n"); auto kernel = quantize_mxfp8_kernel; diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3ed540b8f2..1adabcded2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -234,16 +234,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - printf(">>>>>> qkv_layout: %d\n", qkv_layout); - printf(">>>>>> q_dtype: %d\n", q_dtype); - printf(">>>>>> qkv_format: %d\n", qkv_format); - printf(">>>>>> q_format: %d\n", q_format); - printf(">>>>>> kv_format: %d\n", kv_format); - printf(">>>>>> layout_group: %d\n", layout_group); - printf(">>>>>> cudnn_runtime_version: %d\n", cudnn_runtime_version); - printf(">>>>>> is_training: %d\n", is_training); - printf(">>>>>> bias_type: %d\n", bias_type); - printf(">>>>>> attn_mask_type: %d\n", attn_mask_type); + printf(">>>>>> nvte_get_fused_attn_backend qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend q_dtype: %d, %d, %d\n", q_dtype, NVTEDType::kNVTEFloat8E4M3, NVTEDType::kNVTEFloat8E5M2); + printf(">>>>>> nvte_get_fused_attn_backend qkv_format: %d, %d, %d\n", qkv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend q_format: %d, %d, %d\n", q_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend kv_format: %d, %d, %d\n", kv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); + printf(">>>>>> nvte_get_fused_attn_backend layout_group: %d, %d, %d\n", layout_group, NVTE_QKV_Layout_Group::NVTE_SD_SD_SD, NVTE_QKV_Layout_Group::NVTE_HD_HD_HD); + printf(">>>>>> nvte_get_fused_attn_backend cudnn_runtime_version: %d\n", cudnn_runtime_version); + printf(">>>>>> nvte_get_fused_attn_backend is_training: %d\n", is_training); + printf(">>>>>> nvte_get_fused_attn_backend bias_type: %d\n", bias_type); + printf(">>>>>> nvte_get_fused_attn_backend attn_mask_type: %d, %d, %d\n", attn_mask_type, NVTE_Mask_Type::NVTE_NO_MASK, NVTE_Mask_Type::NVTE_CAUSAL_MASK); if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && @@ -270,7 +270,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BHSD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -531,6 +530,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } + printf(">>>>>> nvte_get_fused_attn_backend fused_attention_backend: %d\n", backend); return backend; } @@ -1202,7 +1202,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, false); - printf(">>>>>> fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f7698da5c3..71d86843b5 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1681,21 +1681,13 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); - printf(">>>>>> cudnn_frontend::DataType_t::UINT8: %d\n", cudnn_frontend::DataType_t::UINT8); - printf(">>>>>> cudnn_frontend::DataType_t::INT8: %d\n", cudnn_frontend::DataType_t::INT8); - printf(">>>>>> cudnn_frontend::DataType_t::HALF: %d\n", cudnn_frontend::DataType_t::HALF); - printf(">>>>>> cudnn_frontend::DataType_t::INT64: %d\n", cudnn_frontend::DataType_t::INT64); - printf(">>>>>> cudnn_frontend::DataType_t::DOUBLE: %d\n", cudnn_frontend::DataType_t::DOUBLE); - printf(">>>>>> bias_type: %d\n", bias_type); - printf(">>>>>> mask_type: %d\n", mask_type); - printf(">>>>>> scaling_factor: %f\n", scaling_factor); - printf(">>>>>> dropout_probability: %f\n", dropout_probability); try { FADescriptor_v1 descriptor{b, @@ -1777,11 +1769,9 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - printf(">>>>>> layout: %d\n", layout); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, @@ -1979,16 +1969,10 @@ void fused_attn_fp8_fwd_impl_v1( : std::make_tuple(nullptr, nullptr); NVTE_CHECK_CUDNN_FE(mha_graph->validate()); - printf(">>>>>> mha_graph->validate()\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); - printf(">>>>>> mha_graph->build_operation_graph(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); - printf(">>>>>> mha_graph->create_execution_plans({fe::HeurMode_t::A})\n"); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); - printf(">>>>>> mha_graph->check_support(handle)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - printf(">>>>>> mha_graph->build_plans(handle)\n"); - printf(">>>>>> mha_graph->get_workspace_size(): %zu\n", mha_graph->get_workspace_size()); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index da1bf03b02..950d67155b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -120,7 +120,7 @@ def float8_block_scaling(cls): @classmethod def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return isinstance(cls, CustomRecipe) @dataclass() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5cc23eabd8..47f7e0f222 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,6 +29,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, @@ -174,15 +175,23 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] + assert qkv_layout == "sbhd_sbhd_sbhd", "sbhd_sbhd_sbhd is assumed to be the shape always at this point in UnfusedDotProductAttention." q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( - qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype, des_nominal_dtype=query_layer.dtype ) + if isinstance(quantizer, MXFP8Quantizer): + assert qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: - t_fp8 = quantizer(tensor1) - tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + if quantizer is not None: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) else: tensors = (tensor1, tensor2, tensor3) ctx.quantizer = quantizer @@ -376,6 +385,7 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) + apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 if "padding" in attn_mask_type and attention_mask is None: attention_mask = dpa_utils.get_padding_mask( @@ -402,9 +412,6 @@ def forward( ) ) - batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] - apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 - # [b, np, sq, sk] output_size = ( query_layer.size(1), @@ -424,11 +431,6 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) - # preallocting result tensor: [b * np, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], @@ -447,6 +449,11 @@ def forward( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) ) + # disable swizzle for MXFP8Quantizer + for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: + if isinstance(q, MXFP8Quantizer): + q.optimize_for_gemm = False + q.internal = False # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: @@ -459,18 +466,21 @@ def forward( fp8_dtype=dP_quantizer.dtype, device="cuda" ) - if "2" in qkv_layout or "3" in qkv_layout: - qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) - qkv_layout = "_".join([qkv_format] * 3) + # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", "sbhd_sbhd_sbhd" ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", "sbhd_sbhd_sbhd" ) + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + # Raw attention scores. [b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -600,14 +610,14 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [b, sq, hp] - context_layer = context_layer.view(batch_size, seqlen, -1) + context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": # [b, np, sq, hn] --> [b, sq, np, hn] @@ -1207,6 +1217,9 @@ def forward( # whether bwd kernel in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # save original qkv_layout + original_qkv_layout = qkv_layout + # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) @@ -1217,20 +1230,18 @@ def forward( # FP8 attention: torch.float16 or torch.bfloat16 out_nominal_dtype = q.dtype - # save original qkv_layout - original_qkv_layout = qkv_layout - max_logit = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E4M3 + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; + # dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(original_qkv_layout, q, k, v, QKV_quantizer) # print quantizers print_quantizers( @@ -1248,6 +1259,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1280,27 +1292,20 @@ def forward( cuda_graph=is_graph_capturing(), ) if original_qkv_layout != qkv_layout: - print(f">>>>>>>>>>>> original_qkv_layout: {original_qkv_layout}") - print(f">>>>>>>>>>>> qkv_layout: {qkv_layout}") - print(f">>>>>>>>>>>> out_.shape: {out_.shape}") original_qkv_format = original_qkv_layout.split("_")[0] new_qkv_format = qkv_layout.split("_")[0] perm = [] for i in new_qkv_format: perm.append(original_qkv_format.find(i)) out_ = out_.permute(*perm).contiguous() - print(f">>>>>>>>>>>> out_.shape permuted: {out_.shape}") - # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ print(f"out_: {type(out_)} {out_.shape}") - print(f"is_output_fp8: {is_output_fp8}") - print(f"is_bwd_fp8: {is_bwd_fp8}") - print(f"fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}") - print(f"_dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") + print(f"is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}, fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}, _dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): print(f"dequantizing out_") if not is_output_fp8 or not is_bwd_fp8: @@ -1451,6 +1456,7 @@ def forward( else: ctx.qkv_layout = qkv_layout else: + ctx.original_qkv_layout = original_qkv_layout ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type @@ -1539,6 +1545,14 @@ def backward(ctx, d_out, *_args): # FP8 attention: torch.float16 or torch.bfloat16 dqkv_nominal_dtype = ctx.nominal_dtype + if ctx.original_qkv_layout != ctx.qkv_layout: + original_qkv_format = ctx.original_qkv_layout.split("_")[0] + new_qkv_format = ctx.qkv_layout.split("_")[0] + perm = [] + for i in original_qkv_format: + perm.append(new_qkv_format.find(i)) + d_out = d_out.permute(*perm).contiguous() + if ctx.fp8: # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1567,20 +1581,20 @@ def backward(ctx, d_out, *_args): # fp8_dtype = tex.DType.kFloat8E4M3 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # out_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # DelayedScaling: + # out_, dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # - # dq_, dk_, dv_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_ = ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8 - ) + # Float8CurrentScaling: + # out_, dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_ = out + if ctx.fp8_recipe.mxfp8_block_scaling(): + out_ = out + aux_ctx_tensors.append(d_out) + dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1615,8 +1629,8 @@ def backward(ctx, d_out, *_args): # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_float8tensor = isinstance(dq_, Float8Tensor) - if is_float8tensor and not ctx.is_input_fp8: + is_quantized_tensor = isinstance(dq_, QuantizedTensor) + if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( ctx.qkv_layout, @@ -1627,7 +1641,7 @@ def backward(ctx, d_out, *_args): ) if not is_float8tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv, ctx.qkv_layout = combine_and_quantize( + dq, dk, dv, _ = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a5931188dc..244f24111d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1392,7 +1392,7 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index eb905d7b93..55553d30be 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -583,9 +583,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - print(f"fp8_recipe: {fp8_recipe}") - # if fp8_recipe.custom(): - # return + if fp8_recipe.custom(): + return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c5ba652c28..d3c2e01814 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,12 +35,14 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -2099,6 +2101,8 @@ def get_attention_quantizers(fp8, quantizers): O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer.set_usage(rowwise=True, columnwise=False) if isinstance(QKV_quantizer, MXFP8Quantizer): + QKV_quantizer.optimize_for_gemm = True + # QKV_quantizer.internal = False S_quantizer = None else: S_quantizer = quantizers["scaling_fwd"][META_S] @@ -2184,67 +2188,49 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_format, _, _ = get_qkv_format(qkv_layout) + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype - print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") if isinstance(qkv_quantizer, MXFP8Quantizer): - # bs3hd -> bshd_bshd_bhsd - q,k,v = [x.contiguous() for x in [q, k, v]] - print(f">>>>>>>>>>>> Contiguous shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") - - # bshd_bshd_bhsd -> bhsd_bhsd_bhsd - # thd_thd_thd -> htd_htd_htd - qkv_quantizer.optimize_for_gemm = True - qkv_quantizer.internal = False - print(f">>>>>>>>>>>> qkv_quantizer.internal: {qkv_quantizer.internal}") - dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") - def permute_x(x): - dim_others = [i for i in range(len(x.shape)) if i != dim_s] - perm = [*dim_others[:-1], dim_s, dim_others[-1]] + print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") + + def permute_x(f, x): + x = x.contiguous() if not x.is_contiguous() else x + dim_s_dim_t = f.find("s") if 's' in f else f.find("t") + dim_others = [i for i in range(len(x.shape)) if i != dim_s_dim_t] + perm = [*dim_others[:-1], dim_s_dim_t, dim_others[-1]] x = x.permute(*perm).contiguous() return x - q, k, v = [permute_x(x) for x in [q, k, v]] + + # bs3hd, sb3hd, etc -> bshd_bshd_bhsd -> bhsd_bhsd_bhsd + # t3hd, etc -> thd_thd_thd -> htd_htd_htd + if q_format not in ["bhsd", "htd"]: + q = permute_x(q_format, q) + if kv_format not in ["bhsd", "htd"]: + k = permute_x(kv_format, k) + v = permute_x(kv_format, v) print(f">>>>>>>>>>>> Permuted shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" - original_shapes = [q.shape, k.shape, v.shape] - b, h_q, s_q, d_qk = q.shape - _, h_kv, s_kv, d_kv = v.shape - assert k.shape == (b, h_kv, s_kv, d_qk) + original_shapes = [x.shape for x in [q, k, v]] + s_q, d_qk = q.shape[-2:] + s_kv, d_kv = v.shape[-2:] assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 assert d_kv % 32 == 0 + # need to check seqlens in THD % 128 == 0 q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # consider bhsd for now grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) - print(f">>>>>>>>>>>> grouped_tensor: {type(grouped_tensor)}") - print(f">>>>>>>>>>>> grouped_tensor.quantized_tensors: {type(grouped_tensor.quantized_tensors)}") q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - q_fp8, k_fp8, v_fp8 = [x.view(*original_shapes[i]) for i, x in enumerate([q_fp8, k_fp8, v_fp8])] - print(f">>>>>>>>>>>> q_fp8: {q_fp8.shape}, k_fp8: {k_fp8.shape}, v_fp8: {v_fp8.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") - # dim_s = qkv_format.find("s") if 's' in qkv_format else qkv_format.find("t") - # dim_others = [i for i in range(len(v.shape)) if i != dim_s] - # perm = [*dim_others, dim_s] - # # perm = [*dim_others[:-1], dim_s, dim_others[-1]] - # v = v.permute(*perm).contiguous() - - qkv_layout = "bhsd_bhsd_bhsd" - - # inv = [0] * len(perm) - # for i, p in enumerate(perm): - # inv[p] = i - # # v = v.permute(*inv) return q_fp8, k_fp8, v_fp8, qkv_layout @@ -2296,14 +2282,20 @@ def combine_and_dequantize( """Combine q,k,v based on qkv_layout and dequantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensor) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" if des_nominal_dtype is None: des_nominal_dtype = src_nominal_dtype + if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): + print(f"Combining and dequantizing q, k, v from MXFP8 to {des_nominal_dtype}") + q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] + return q, k, v + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] match qkv_group: case 1: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b4811eb4f6..09953440e9 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -16,7 +16,6 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer __all__ = [ diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index e4c658ed58..6c72d74531 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -75,8 +75,7 @@ def update_quantized( src = src.contiguous() # Launch cast kernel - print(f">>>>>>>>>>>> src: {src.shape}") - print(f">>>>>>>>>>>> dst: {dst.shape}") + print(f"MXFP8Quantizer.update_quantized: src: {src.shape}, dst: {dst.shape}") tex.quantize(src, self, dst, noop_flag) # Update FP8 dtype @@ -86,7 +85,7 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - print(f"Quantizing tensor: {tensor.shape}") + print(f"MXFP8Quantizer.quantize_impl: tensor: {tensor.shape}") return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 9771d61df8..5a8d323983 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -91,24 +91,6 @@ def __init__( offsets: Vector of integer offsets for each tensor. logical_shape: 2D tuple representing conceptual shape """ - # print(f">>>>>>>>>>>> GroupedTensor init: {quantizers}") - # print(f">>>>>>>>>>>> shape: {shape}") - # print(f">>>>>>>>>>>> dtype: {dtype}") - # print(f">>>>>>>>>>>> data: {data.shape}") - # print(f">>>>>>>>>>>> columnwise_data: {columnwise_data.shape if columnwise_data is not None else None}") - # print(f">>>>>>>>>>>> scale_inv: {scale_inv.shape if scale_inv is not None else None}") - # print(f">>>>>>>>>>>> columnwise_scale_inv: {columnwise_scale_inv.shape if columnwise_scale_inv is not None else None}") - # print(f">>>>>>>>>>>> amax: {amax.shape if amax is not None else None}") - # print(f">>>>>>>>>>>> columnwise_amax: {columnwise_amax.shape if columnwise_amax is not None else None}") - # print(f">>>>>>>>>>>> scale: {scale.shape if scale is not None else None}") - # print(f">>>>>>>>>>>> first_dims: {first_dims.shape}") - # print(f">>>>>>>>>>>> last_dims: {last_dims.shape}") - # print(f">>>>>>>>>>>> tensor_offsets: {tensor_offsets.shape}") - # print(f">>>>>>>>>>>> offsets: {offsets.shape}") - # print(f">>>>>>>>>>>> scale_inv_offsets: {scale_inv_offsets.shape}") - # print(f">>>>>>>>>>>> columnwise_scale_inv_offsets: {columnwise_scale_inv_offsets.shape}") - # print(f">>>>>>>>>>>> logical_shape: {logical_shape.shape}") - print(f">>>>>>>>>>>> num_tensors: {num_tensors}") self.num_tensors = num_tensors self.quantizer = quantizer @@ -519,7 +501,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) if i < num_tensors - 1: @@ -554,7 +536,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shape): + for i, s in enumerate(shapes): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) if i < num_tensors - 1: @@ -934,7 +916,7 @@ def create_and_quantize( Quantize given tensors into quantized tensors with underlying storage allocated in a GroupedTensor. """ - + print(f">>>>>>>>>>>> GroupedTensor create_and_quantize") grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=len(tensors), shapes=[t.shape for t in tensors], @@ -960,15 +942,6 @@ def create_and_quantize( _ = tex.quantize_grouped(grouped_input, grouped_output) grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() - # grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - # num_tensors=len(tensors), - # shapes=[t.shape for t in tensors], - # quantizer=None, - # device=device, - # dtype=tensors[0].dtype, - # ) - # grouped_tensor.quantize(tensors, noop_flag=noop_flag) - return grouped_output def quantize( @@ -979,10 +952,8 @@ def quantize( """ Quantize the GroupedTensor inplace. """ - + print(f">>>>>>>>>>>> GroupedTensor quantize") self.quantized_tensors = self.split_into_quantized_tensors() - print(f">>>>>>>>>>>> tensors[0]: {type(tensors[0])}") - print(f">>>>>>>>>>>> quantized_tensors[0]: {type(self.quantized_tensors[0])}") for i in range(self.num_tensors): self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) return self.quantized_tensors \ No newline at end of file From cd06398d2c57d021c31330318eb40ca8567578d4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:54:12 -0800 Subject: [PATCH 10/59] split d to d_qk/d_v; attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 30 +- .../common/fused_attn/fused_attn_fp8.cu | 434 +++++++++++++----- .../common/fused_attn/fused_attn_fp8.h | 6 +- 3 files changed, 354 insertions(+), 116 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1adabcded2..98ff96b666 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -639,7 +639,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -772,6 +772,10 @@ void nvte_fused_attn_bwd_qkvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + const Tensor *input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; @@ -787,8 +791,8 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -945,7 +949,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1090,6 +1094,10 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + const Tensor *input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } // Unpack KV and dKV and call the non-packed function const auto Q_type = input_Q->data.dtype; @@ -1104,9 +1112,9 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1228,7 +1236,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1340,9 +1348,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + const Tensor *input_dO_f16; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + } + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 71d86843b5..d9af04c628 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,7 +1652,7 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, @@ -1681,7 +1681,7 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d: %d\n", b, h, hg, s_q, s_kv, d); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); @@ -1695,8 +1695,8 @@ void fused_attn_fp8_fwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -1772,36 +1772,39 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); int32_t block_size = 32; - int64_t d_scale = (d + block_size - 1) / block_size; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; int64_t s_q_padded = ((s_q + 127) / 128) * 128; int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; - int64_t d_scale_padded = ((d_scale + 3) / 4) * 4; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_padded = ((d + 3) / 4) * 4; - printf(">>>>>> d_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_scale_padded: %d, s_kv_scale_padded: %d, d_padded: %d\n", d_scale, s_kv_scale, s_q_padded, s_kv_padded, d_scale_padded, s_kv_scale_padded, d_padded); - std::vector q_scale_dims = {b, h, s_q_padded, d_scale_padded}; - std::vector k_scale_dims = {b, hg, s_kv_padded, d_scale_padded}; - std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_padded}; + int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; + int64_t d_v_padded = ((d_v + 3) / 4) * 4; + printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); + std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; + std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; + std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_v_padded}; std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_scale_padded, q_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_scale_padded, k_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); @@ -1809,17 +1812,17 @@ void fused_attn_fp8_fwd_impl_v1( Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_data_type(qkv_tensor_type)); @@ -1931,9 +1934,9 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); + O->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2050,7 +2053,7 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, @@ -2058,10 +2061,11 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, + void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2075,13 +2079,24 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || - dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> scaling_mode: %d\n", scaling_mode); + printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); + printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); + printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); + printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf(">>>>>> o_tensor_type: %d, %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); + printf(">>>>>> do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); + printf(">>>>>> dqkv_tensor_type: %d, %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2091,8 +2106,8 @@ void fused_attn_fp8_bwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -2122,17 +2137,24 @@ void fused_attn_fp8_bwd_impl_v1( using graph_and_tensors = std::tuple, std::shared_ptr, // q + std::shared_ptr, // q_t std::shared_ptr, // k + std::shared_ptr, // k_t std::shared_ptr, // v std::shared_ptr, // o std::shared_ptr, // stats std::shared_ptr, // dO + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2181,40 +2203,45 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr scale_dQ, scale_dK, scale_dV; std::shared_ptr bias, dBias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + std::shared_ptr q_t, k_t, dO_t, dO_f16, descale_q_t, descale_k_t, descale_dO_t; std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(o_stride) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(o_stride) + .set_data_type(do_tensor_type)); stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("stats") .set_dim({b, h, s_q, 1}) @@ -2228,33 +2255,151 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { - descale_o = mha_graph->tensor(1.0f); + if (!is_mxfp8) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + if (is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } } else { - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); - } - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - - if (is_delayed_scaling) { - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); - } - if (is_current_scaling) { - scale_dQ = mha_graph->tensor(1.0f); - scale_dK = mha_graph->tensor(1.0f); - scale_dV = mha_graph->tensor(1.0f); + std::vector q_t_stride(4); + std::vector k_t_stride(4); + std::vector dO_t_stride(4); + generateMatrixStrides(b, h, d_qk, s_kv, s_q, q_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + generateMatrixStrides(b, h, d_qk, s_kv, s_q, dO_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + int32_t block_size = 32; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t s_q_scale = (s_q + block_size - 1) / block_size; + int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; + int64_t d_v_padded = ((d_v + 3) / 4) * 4; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; + int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; + printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_q_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_q_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_q_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); + // std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; + // std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; + // std::vector v_scale_dims = {b, hg, s_kv_padded, d_v_scale_padded}; + // std::vector q_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; + // std::vector k_t_scale_dims = {b, hg, s_kv_scale_padded, d_qk_padded}; + // // std::vector dO_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; + // // std::vector dO_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + std::vector q_t_scale_strides(4); + std::vector k_t_scale_strides(4); + // std::vector dO_scale_strides(4); + // std::vector dO_t_scale_strides(4); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, h, d_qk_padded, s_kv_scale_padded, s_q_scale_padded, q_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + printf(">>>>>> q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); + printf(">>>>>> k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); + + q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); + k_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_qk}) + .set_stride(dO_t_stride) + .set_data_type(o_tensor_type)); + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q_t") + .set_dim({b, h, s_q_scale_padded, d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k_t") + .set_dim({b, hg, s_kv_scale_padded, d_qk_padded}) + .set_stride(k_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, s_kv_padded, d_v_scale_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO") + .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO_t") + .set_dim({b, h, s_q_scale_padded, d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; @@ -2312,14 +2457,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - + // if (!is_mxfp8) { auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + // } else { + // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( + // q, q_t, k, k_t, v, o, dO_f16, dO, dO_t, stats, descale_q, descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, descale_dO_t, + // sdpa_backward_options); + // } + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride).set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2332,28 +2481,36 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + if (!is_mxfp8) { amax_dP->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - - dO->set_data_type(do_tensor_type); - dQ->set_data_type(dqkv_tensor_type); - dK->set_data_type(dqkv_tensor_type); - dV->set_data_type(dqkv_tensor_type); + } + // dO->set_data_type(do_tensor_type); + // dQ->set_data_type(dqkv_tensor_type); + // dK->set_data_type(dqkv_tensor_type); + // dV->set_data_type(dqkv_tensor_type); std::tuple, // q + // std::shared_ptr, // q_t std::shared_ptr, // k + // std::shared_ptr, // k_t std::shared_ptr, // v std::shared_ptr, // o std::shared_ptr, // stats std::shared_ptr, // dO + // std::shared_ptr, // dO_t + // std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + // std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + // std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + // std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2372,6 +2529,8 @@ void fused_attn_fp8_bwd_impl_v1( q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); + auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + // key_tensors_tuple = std::tuple_cat(key_tensors_tuple, mxfp8_tensors_tuple); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2385,17 +2544,64 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + padding_tuple, dropout_tuple, mxfp8_tensors_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, - descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); - + auto bprop_tuple = get_graph(sdpa_fp8_bprop_cache, descriptor); + // if (!is_mxfp8) { + // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, + // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + auto mha_graph = std::get<0>(bprop_tuple); + auto q = std::get<1>(bprop_tuple); + auto k = std::get<2>(bprop_tuple); + auto v = std::get<3>(bprop_tuple); + auto o = std::get<4>(bprop_tuple); + auto stats = std::get<5>(bprop_tuple); + auto dO = std::get<6>(bprop_tuple); + auto attn_scale = std::get<7>(bprop_tuple); + auto descale_q = std::get<8>(bprop_tuple); + auto descale_k = std::get<9>(bprop_tuple); + auto descale_v = std::get<10>(bprop_tuple); + auto descale_o = std::get<11>(bprop_tuple); + auto descale_dO = std::get<12>(bprop_tuple); + auto descale_s = std::get<13>(bprop_tuple); + auto descale_dP = std::get<14>(bprop_tuple); + auto scale_s = std::get<15>(bprop_tuple); + auto scale_dQ = std::get<16>(bprop_tuple); + auto scale_dK = std::get<17>(bprop_tuple); + auto scale_dV = std::get<18>(bprop_tuple); + auto scale_dP = std::get<19>(bprop_tuple); + auto dQ = std::get<20>(bprop_tuple); + auto dK = std::get<21>(bprop_tuple); + auto dV = std::get<22>(bprop_tuple); + auto amax_dQ = std::get<23>(bprop_tuple); + auto amax_dK = std::get<24>(bprop_tuple); + auto amax_dV = std::get<25>(bprop_tuple); + auto amax_dP = std::get<26>(bprop_tuple); + auto bias = std::get<27>(bprop_tuple); + auto dBias = std::get<28>(bprop_tuple); + auto seq_q = std::get<29>(bprop_tuple); + auto seq_kv = std::get<30>(bprop_tuple); + auto dropout_seed = std::get<31>(bprop_tuple); + auto dropout_offset = std::get<32>(bprop_tuple); + // } else { + // if (is_mxfp8) { + // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, + // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + auto q_t = std::get<33>(bprop_tuple); + auto k_t = std::get<34>(bprop_tuple); + auto dO_f16 = std::get<35>(bprop_tuple); + auto dO_t = std::get<36>(bprop_tuple); + auto descale_q_t = std::get<37>(bprop_tuple); + auto descale_k_t = std::get<38>(bprop_tuple); + auto descale_dO_t = std::get<39>(bprop_tuple); + // } auto plan_workspace_size = mha_graph->get_workspace_size(); // Exit to request upper level API to allocate memory if needed @@ -2422,25 +2628,36 @@ void fused_attn_fp8_bwd_impl_v1( {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, {descale_dO, devPtrDescaledO}, - {descale_s, devPtrDescaleS}, - {descale_dP, devPtrDescaledP}, - {scale_s, devPtrScaleS}, - {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, {amax_dQ, devPtrAmaxdQ}, {amax_dK, devPtrAmaxdK}, {amax_dV, devPtrAmaxdV}, - {amax_dP, devPtrAmaxdP}, }; + if (!is_mxfp8) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + } else { + variant_pack[q_t] = devPtrQ_t; + variant_pack[k_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO] = devPtrDescaledO; + variant_pack[descale_dO_t] = devPtrDescaledO_t; + } if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (!is_O_in_F16) { + if (is_current_scaling && !is_O_in_F16) { variant_pack[descale_o] = devPtrDescaleO; } @@ -2485,7 +2702,7 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, @@ -2576,7 +2793,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, @@ -2585,7 +2802,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, @@ -2609,11 +2826,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, - const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, @@ -2626,6 +2843,10 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleQ = input_Q->scale_inv.dptr; void* devPtrDescaleK = input_Q->scale_inv.dptr; void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrQ_t = input_Q->columnwise_data.dptr; + void* devPtrK_t = input_K->columnwise_data.dptr; + void* devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; + void* devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2635,6 +2856,9 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; + void* devPtrdO_t = input_dO->columnwise_data.dptr; + void* devPtrdO_f16 = input_dO_f16->data.dptr; + void* devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; void* devPtrM = input_M->data.dptr; void* devPtrZInv = input_ZInv->data.dptr; @@ -2672,18 +2896,20 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index a1a932fdf5..f335bc3d85 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,7 +15,7 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, @@ -26,11 +26,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, From 730a472af86c2a47d996035f9a8bd5e4c409c0f0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:55:59 -0800 Subject: [PATCH 11/59] fix last merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 3 --- transformer_engine/common/common.h | 4 +--- transformer_engine/pytorch/tensor/storage/grouped_tensor.py | 4 ---- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 64896810f8..7a8ab8062c 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -852,9 +852,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; - printf(">>>>>>>>>>>> group_quantize_mxfp8 with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b14653aca7..2d7f0e7e8c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -355,9 +355,7 @@ struct GroupedTensor { last_dims(nullptr, std::vector{0}, DType::kInt64), tensor_offsets(nullptr, std::vector{0}, DType::kInt64), logical_shape(nvte_make_shape(nullptr, 1)), - scaling_mode(scaling_mode), - nvte_tensor(0), - with_gemm_swizzled_scales(false) {} + nvte_tensor(0) {} explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; } diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index b6c8818ab8..123dfcf22a 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -8,12 +8,8 @@ import math import torch -<<<<<<< HEAD import transformer_engine import transformer_engine_torch as tex -======= - ->>>>>>> main from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor From d9ff5662aa4b4b6267c77baf614aada6602fa133 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:56:47 -0800 Subject: [PATCH 12/59] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 ++- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 209a25fe89..4b4df2edcf 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 +Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d From 2b264d72f663707ccb923d7259603c53872306d3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:00:51 -0800 Subject: [PATCH 13/59] attempt at SWA/MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 13 +++++-- .../common/fused_attn/fused_attn.cpp | 18 +++++---- .../common/fused_attn/fused_attn_fp8.cu | 39 ++++++++++++------- .../common/fused_attn/fused_attn_fp8.h | 4 +- .../attention/dot_product_attention/utils.py | 22 +++++------ 5 files changed, 57 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 74deeceed2..05d76d96fe 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1789,7 +1789,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), @@ -2259,7 +2259,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -2304,7 +2304,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim_qk, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -2320,6 +2321,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] if config.dropout_p == 0.0: tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") @@ -2344,6 +2349,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") @@ -2354,6 +2360,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: inp[1], inp[2], qkv_format=qkv_format, + window_size=config.window_size, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 98ff96b666..6f343d90b2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -269,7 +269,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: mxfp8, d_qk=128, d_v=192 + (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -425,7 +427,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + ((window_size_left == -1 || window_size_left >= 0) && (window_size_right == -1 || window_size_right >= 0) && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && @@ -640,7 +642,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -792,7 +794,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, + bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -950,7 +952,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1113,7 +1115,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1237,7 +1239,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1353,7 +1355,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f80bf933f7..fdf78fcef3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1654,7 +1654,7 @@ void fused_attn_fp8_bwd_impl( void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, @@ -1682,6 +1682,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); @@ -1712,8 +1713,8 @@ void fused_attn_fp8_fwd_impl_v1( bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, + window_size_left, + window_size_right, true, true, qkv_tensor_type, @@ -1786,10 +1787,12 @@ void fused_attn_fp8_fwd_impl_v1( int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t s_q_scale = (s_q + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 3) / 4) * 4; + int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; int64_t d_v_padded = ((d_v + 3) / 4) * 4; @@ -1804,7 +1807,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); @@ -1876,6 +1879,8 @@ void fused_attn_fp8_fwd_impl_v1( .set_generate_stats(true) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + sdpa_options.set_diagonal_band_right_bound(window_size_right); // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2055,7 +2060,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -2088,6 +2093,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); printf(">>>>>> scaling_mode: %d\n", scaling_mode); printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); @@ -2123,8 +2129,8 @@ void fused_attn_fp8_bwd_impl_v1( bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, + window_size_left, + window_size_right, true, false, qkv_tensor_type, @@ -2299,8 +2305,8 @@ void fused_attn_fp8_bwd_impl_v1( int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t s_q_scale = (s_q + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t s_q_padded = ((s_q + 3) / 4) * 4; + int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; int64_t d_v_padded = ((d_v + 3) / 4) * 4; int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; @@ -2408,6 +2414,9 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + // sdpa_backward_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2705,7 +2714,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, @@ -2794,7 +2803,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), @@ -2828,7 +2837,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2897,7 +2906,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index f335bc3d85..22800b2aa2 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -18,7 +18,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, @@ -28,7 +28,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index d3c2e01814..873c101521 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -876,12 +876,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention for FP8" - ) - use_fused_attention = False - elif attention_dropout != 0.0: + if attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " "without dropout" @@ -1016,7 +1011,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention and window_size is not None and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( "Disabling FusedAttention as only sub-backend %s does not support " @@ -2214,18 +2209,23 @@ def permute_x(f, x): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] - s_kv, d_kv = v.shape[-2:] + s_kv, d_v = v.shape[-2:] assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 - assert d_kv % 32 == 0 + assert d_v % 32 == 0 # need to check seqlens in THD % 128 == 0 q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # consider bhsd for now - grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) - q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + if d_qk == d_v: + grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors + else: + q_fp8 = qkv_quantizer(q) + k_fp8 = qkv_quantizer(k) + v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") From 2008bed824b69eb21650d146e18916f0d7f872e0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:13:40 -0800 Subject: [PATCH 14/59] remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/cast.cu | 1 - .../common/cast/dispatch/quantize.cuh | 1 - .../common/fused_attn/fused_attn.cpp | 11 ------ .../common/fused_attn/fused_attn_fp8.cu | 34 ------------------- .../dot_product_attention/backends.py | 4 --- .../attention/dot_product_attention/utils.py | 8 ----- .../pytorch/csrc/extensions/cast.cpp | 1 - .../pytorch/csrc/type_converters.cpp | 1 - .../pytorch/tensor/mxfp8_tensor.py | 2 -- 9 files changed, 63 deletions(-) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 624b0bfc7c..12d816f708 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,7 +30,6 @@ void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; - printf(">>>>>>>>>>>> nvte_group_quantize\n"); constexpr bool IS_ACT = false; dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9a6e9b01d6..b83df1dedf 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -375,7 +375,6 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, template void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - printf(">>>>>>>>>>>> group_quantize_fwd_helper\n"); using namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6f343d90b2..d58fca70e9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -234,16 +234,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - printf(">>>>>> nvte_get_fused_attn_backend qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend q_dtype: %d, %d, %d\n", q_dtype, NVTEDType::kNVTEFloat8E4M3, NVTEDType::kNVTEFloat8E5M2); - printf(">>>>>> nvte_get_fused_attn_backend qkv_format: %d, %d, %d\n", qkv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend q_format: %d, %d, %d\n", q_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend kv_format: %d, %d, %d\n", kv_format, NVTE_QKV_Format::NVTE_BHSD, NVTE_QKV_Format::NVTE_BSHD); - printf(">>>>>> nvte_get_fused_attn_backend layout_group: %d, %d, %d\n", layout_group, NVTE_QKV_Layout_Group::NVTE_SD_SD_SD, NVTE_QKV_Layout_Group::NVTE_HD_HD_HD); - printf(">>>>>> nvte_get_fused_attn_backend cudnn_runtime_version: %d\n", cudnn_runtime_version); - printf(">>>>>> nvte_get_fused_attn_backend is_training: %d\n", is_training); - printf(">>>>>> nvte_get_fused_attn_backend bias_type: %d\n", bias_type); - printf(">>>>>> nvte_get_fused_attn_backend attn_mask_type: %d, %d, %d\n", attn_mask_type, NVTE_Mask_Type::NVTE_NO_MASK, NVTE_Mask_Type::NVTE_CAUSAL_MASK); if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && @@ -532,7 +522,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - printf(">>>>>> nvte_get_fused_attn_backend fused_attention_backend: %d\n", backend); return backend; } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fdf78fcef3..5c809c6050 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1681,14 +1681,6 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); - printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); - printf(">>>>>> scaling_mode: %d\n", scaling_mode); - printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); - printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); - printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); - printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -1779,9 +1771,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check - printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); - printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); - printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); int32_t block_size = 32; int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; @@ -1796,7 +1785,6 @@ void fused_attn_fp8_fwd_impl_v1( int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; int64_t d_v_padded = ((d_v + 3) / 4) * 4; - printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_v_padded}; @@ -1809,9 +1797,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -2049,7 +2034,6 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } - printf(">>>>>> mha_graph->execute(handle, variant_pack, workspace)\n"); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2092,16 +2076,6 @@ void fused_attn_fp8_bwd_impl_v1( dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); - printf(">>>>>> window_size_left: %d, window_size_right: %d\n", window_size_left, window_size_right); - printf(">>>>>> scaling_mode: %d\n", scaling_mode); - printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); - printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); - printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); - printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf(">>>>>> o_tensor_type: %d, %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); - printf(">>>>>> do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FLOAT, cudnn_frontend::DataType_t::BFLOAT16); - printf(">>>>>> dqkv_tensor_type: %d, %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2313,7 +2287,6 @@ void fused_attn_fp8_bwd_impl_v1( int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - printf(">>>>>> d_qk_scale: %d, d_v_scale: %d, s_q_scale: %d, s_kv_scale: %d, s_q_padded: %d, s_kv_padded: %d, d_qk_scale_padded: %d, d_v_scale_padded: %d, s_q_scale_padded: %d, s_kv_scale_padded: %d, d_qk_padded: %d, d_v_padded: %d\n", d_qk_scale, d_v_scale, s_q_scale, s_kv_scale, s_q_padded, s_kv_padded, d_qk_scale_padded, d_v_scale_padded, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, d_v_padded); // std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; // std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; // std::vector v_scale_dims = {b, hg, s_kv_padded, d_v_scale_padded}; @@ -2338,11 +2311,6 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - printf(">>>>>> q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); - printf(">>>>>> k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") @@ -2733,7 +2701,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrScaleS = nullptr; void* devPtrDescaleS = nullptr; if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { - printf(">>>>>> input_Q is MXFP8_1D_SCALING\n"); devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; @@ -2745,7 +2712,6 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrO = output_O->data.dptr; devPtrAmaxO = output_O->amax.dptr; } else { - printf(">>>>>> input_Q is not MXFP8_1D_SCALING\n"); devPtrQ = input_Q->data.dptr; devPtrDescaleQ = input_Q->scale_inv.dptr; devPtrK = input_K->data.dptr; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 47f7e0f222..c0dca1b330 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1304,14 +1304,10 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - print(f"out_: {type(out_)} {out_.shape}") - print(f"is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}, fp8_recipe.float8_current_scaling(): {fp8_recipe.float8_current_scaling()}, _dpa_fp8_cs_o_in_f16: {_dpa_fp8_cs_o_in_f16}") if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): - print(f"dequantizing out_") if not is_output_fp8 or not is_bwd_fp8: out = out_.dequantize().view(out_.shape) else: - print(f"quantizing out_") if is_output_fp8 or ( is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 873c101521..12a75131aa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2187,7 +2187,6 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype if isinstance(qkv_quantizer, MXFP8Quantizer): - print(f"Combining and quantizing q, k, v: {q.shape}, {k.shape}, {v.shape}") def permute_x(f, x): x = x.contiguous() if not x.is_contiguous() else x @@ -2204,7 +2203,6 @@ def permute_x(f, x): if kv_format not in ["bhsd", "htd"]: k = permute_x(kv_format, k) v = permute_x(kv_format, v) - print(f">>>>>>>>>>>> Permuted shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" original_shapes = [x.shape for x in [q, k, v]] @@ -2216,7 +2214,6 @@ def permute_x(f, x): assert d_v % 32 == 0 # need to check seqlens in THD % 128 == 0 q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] - print(f">>>>>>>>>>>> Flattened shapes: q: {q.shape}, k: {k.shape}, v: {v.shape}") # consider bhsd for now if d_qk == d_v: @@ -2227,10 +2224,6 @@ def permute_x(f, x): k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_data.shape}, k_fp8: {k_fp8._rowwise_data.shape}, v_fp8: {v_fp8._rowwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._rowwise_scale_inv.shape}, k_fp8: {k_fp8._rowwise_scale_inv.shape}, v_fp8: {v_fp8._rowwise_scale_inv.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_data.shape}, k_fp8: {k_fp8._columnwise_data.shape}, v_fp8: {v_fp8._columnwise_data.shape}") - print(f">>>>>>>>>>>> q_fp8: {q_fp8._columnwise_scale_inv.shape}, k_fp8: {k_fp8._columnwise_scale_inv.shape}, v_fp8: {v_fp8._columnwise_scale_inv.shape}") return q_fp8, k_fp8, v_fp8, qkv_layout @@ -2292,7 +2285,6 @@ def combine_and_dequantize( des_nominal_dtype = src_nominal_dtype if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): - print(f"Combining and dequantizing q, k, v from MXFP8 to {des_nominal_dtype}") q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] return q, k, v diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 34565bcf44..9d3d6b901d 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -83,7 +83,6 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object quantize_grouped(const py::handle &input, py::handle &output) { using namespace transformer_engine::pytorch::detail; init_extension(); - printf(">>>>>>>>>>>> quantize_grouped\n"); const auto &grouped_input_tensor = GroupedTensorFromPyTorchGroupedTensor(input); const auto &grouped_output_tensor = GroupedTensorFromPyTorchGroupedTensor(output); NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index c17be6c855..8c9d7d7c16 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -212,7 +212,6 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { scaling_mode = ScalingModeFromQuantizer(quantizer); quantizer_dtype = quantizer.attr("dtype").cast(); with_gemm_swizzled_scales = quantizer.attr("optimize_for_gemm").cast(); - printf(">>>>>>>>>>>> GroupedTensorFromPyTorchGroupedTensor with_gemm_swizzled_scales: %d\n", with_gemm_swizzled_scales); } auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6c72d74531..41d6c87f2b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -75,7 +75,6 @@ def update_quantized( src = src.contiguous() # Launch cast kernel - print(f"MXFP8Quantizer.update_quantized: src: {src.shape}, dst: {dst.shape}") tex.quantize(src, self, dst, noop_flag) # Update FP8 dtype @@ -85,7 +84,6 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - print(f"MXFP8Quantizer.quantize_impl: tensor: {tensor.shape}") return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: From 239f58aec1b5c33e6b6e97ca4043c754066f241a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:14:42 -0800 Subject: [PATCH 15/59] remove leftover prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a8135391e3..70a68132ad 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -776,7 +776,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, switch (scaling_type) { case ScalingType::ROWWISE: { - printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::ROWWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -792,7 +791,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { - printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::COLWISE\n"); auto kernel = quantize_mxfp8_kernel; @@ -808,7 +806,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { - printf(">>>>>>>>>>>> quantize_mxfp8 ScalingType::BIDIMENSIONAL\n"); auto kernel = quantize_mxfp8_kernel; From f44a775706a249cef801b162d34f5ff0c9e8c5eb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:23:56 -0800 Subject: [PATCH 16/59] Revert "update FE" This reverts commit d9ff5662aa4b4b6267c77baf614aada6602fa133. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +-- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8c7646c00d..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,8 +3,7 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git - branch = develop + url = https://github.com/NVIDIA/cudnn-frontend.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 4b4df2edcf..209a25fe89 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d +Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 From 965572bc571fe27d932f66ad74c026ee28d40adf Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:29:36 -0800 Subject: [PATCH 17/59] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 ++- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 209a25fe89..4b4df2edcf 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 209a25fe898bb749c8605363a6431e26cb41b089 +Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d From 91025c74f8e5121bb9f195e562e3a18c3a00ba12 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Feb 2026 14:32:53 -0800 Subject: [PATCH 18/59] fix MLA O strides; add bottom_right_diagonal Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 +- .../common/fused_attn/fused_attn.cpp | 12 +- .../common/fused_attn/fused_attn_fp8.cu | 488 +++++++++--------- .../common/fused_attn/fused_attn_fp8.h | 4 +- transformer_engine/common/fused_attn/utils.h | 10 +- .../dot_product_attention/backends.py | 8 + 6 files changed, 273 insertions(+), 251 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 05d76d96fe..ff3c7506e9 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1788,7 +1788,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), + "fp8_9": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d58fca70e9..ebda62568a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -631,7 +631,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -783,7 +783,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -941,7 +941,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1104,7 +1104,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1228,7 +1228,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1344,7 +1344,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 5c809c6050..b4aebea25d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1654,7 +1654,7 @@ void fused_attn_fp8_bwd_impl( void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, @@ -1662,6 +1662,7 @@ void fused_attn_fp8_fwd_impl_v1( cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -1681,6 +1682,13 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); + printf(">>>>>> scaling_mode: %d\n", scaling_mode); + printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); + printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); + printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); + printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -1707,7 +1715,7 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, window_size_left, window_size_right, - true, + bottom_right_diagonal, true, qkv_tensor_type, o_tensor_type, @@ -1762,6 +1770,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; + // Q, K, V, attn_scale std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -1770,102 +1779,112 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); // need to double check - - int32_t block_size = 32; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; - int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_scale = (s_q + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 3) / 4) * 4; - int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; - int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; - int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; - int64_t d_v_padded = ((d_v + 3) / 4) * 4; - std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; - std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; - std::vector v_scale_dims = {b, hg, s_kv_scale_padded, d_v_padded}; - std::vector q_scale_strides(4); - std::vector k_scale_strides(4); - std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - + NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_data_type(qkv_tensor_type)); + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_data_type(qkv_tensor_type)); + .set_name("K") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_data_type(qkv_tensor_type)); - + .set_name("V") + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - if (!is_mxfp8) { + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Scale_o + if (is_delayed_scaling || is_current_scaling) { descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - } else { + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_o"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } + } + if (is_mxfp8) { + int32_t block_size = 32; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; + int64_t s_q_scale = (s_q + block_size - 1) / block_size; + int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; + int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; + int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; + int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; + int64_t d_v_padded = ((d_v + 127) / 128) * 128; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") - .set_dim(q_scale_dims) + .set_dim({b, h, s_q_padded, d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") - .set_dim(k_scale_dims) + .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") - .set_dim(v_scale_dims) + .set_dim({b, hg, s_kv_scale_padded, d_v_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } - if (is_delayed_scaling) { - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); - } - if (is_current_scaling) { - scale_o = mha_graph->tensor(1.0f); - } - fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() .set_name("sdpa_fp8") .set_generate_stats(true) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); - sdpa_options.set_diagonal_band_right_bound(window_size_right); + + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -1924,9 +1943,13 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(o_stride).set_data_type(o_tensor_type); + printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); + printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); + printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); + printf(">>>>>> o_stride: %d, %d, %d, %d\n", o_stride[0], o_stride[1], o_stride[2], o_stride[3]); + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2044,7 +2067,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -2057,6 +2080,7 @@ void fused_attn_fp8_bwd_impl_v1( cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -2105,7 +2129,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, window_size_left, window_size_right, - true, + bottom_right_diagonal, false, qkv_tensor_type, o_tensor_type, @@ -2116,25 +2140,25 @@ void fused_attn_fp8_bwd_impl_v1( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // q_t - std::shared_ptr, // k - std::shared_ptr, // k_t - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::shared_ptr, // Q + // std::shared_ptr, // Q_t + std::shared_ptr, // K + // std::shared_ptr, // K_t + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO - std::shared_ptr, // dO_t - std::shared_ptr, // dO_f16 + // std::shared_ptr, // dO_t + // std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q - std::shared_ptr, // descale_q_t + // std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k - std::shared_ptr, // descale_k_t + // std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO - std::shared_ptr, // descale_dO_t + // std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2175,16 +2199,16 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, descale_v; std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; + std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; std::shared_ptr bias, dBias, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::shared_ptr q_t, k_t, dO_t, dO_f16, descale_q_t, descale_k_t, descale_dO_t; + // Q, K, V, O, dO, stats, attn_scale std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -2195,39 +2219,38 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_data_type(qkv_tensor_type)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() + K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_data_type(qkv_tensor_type)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() + V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_data_type(qkv_tensor_type)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() + O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d_qk}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d_qk}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_data_type(do_tensor_type)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") + Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") .set_dim({b, h, s_q, 1}) .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -2235,25 +2258,25 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - if (!is_mxfp8) { + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Descale_dP, Scale_dP, Descale_o, Descale_dO, Scale_dQ, Scale_dK, Scale_dV + if (is_delayed_scaling || is_current_scaling) { descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + if (is_current_scaling && is_O_in_F16) { descale_o = mha_graph->tensor(1.0f); } else { descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); } descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - if (is_delayed_scaling) { scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); @@ -2264,74 +2287,73 @@ void fused_attn_fp8_bwd_impl_v1( scale_dK = mha_graph->tensor(1.0f); scale_dV = mha_graph->tensor(1.0f); } - } else { + } + if (is_mxfp8) { + // Q_t, K_t, dO_t, dO_f16 std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStrides(b, h, d_qk, s_kv, s_q, q_t_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, h, d_qk, s_kv, s_q, dO_t_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d_v, dO_t_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); + Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); + K_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type)); + // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t int32_t block_size = 32; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t s_q_padded = ((s_q + 127) / 128) * 128; + int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; int64_t s_q_scale = (s_q + block_size - 1) / block_size; int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_padded = ((s_q + 3) / 4) * 4; - int64_t s_kv_padded = ((s_kv + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 3) / 4) * 4; - int64_t d_v_padded = ((d_v + 3) / 4) * 4; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - // std::vector q_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; - // std::vector k_scale_dims = {b, hg, s_kv_padded, d_qk_scale_padded}; - // std::vector v_scale_dims = {b, hg, s_kv_padded, d_v_scale_padded}; - // std::vector q_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; - // std::vector k_t_scale_dims = {b, hg, s_kv_scale_padded, d_qk_padded}; - // // std::vector dO_scale_dims = {b, h, s_q_padded, d_qk_scale_padded}; - // // std::vector dO_t_scale_dims = {b, h, s_q_scale_padded, d_qk_padded}; + int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; + int64_t d_v_padded = ((d_v + 127) / 128) * 128; + int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; + int64_t d_v_scale = (d_v + block_size - 1) / block_size; + int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; + int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; std::vector q_scale_strides(4); - std::vector k_scale_strides(4); - std::vector v_scale_strides(4); std::vector q_t_scale_strides(4); + std::vector k_scale_strides(4); std::vector k_t_scale_strides(4); - // std::vector dO_scale_strides(4); - // std::vector dO_t_scale_strides(4); + std::vector v_scale_strides(4); + std::vector dO_scale_strides(4); + std::vector dO_t_scale_strides(4); generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, h, d_qk_padded, s_kv_scale_padded, s_q_scale_padded, q_t_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - q_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q_t") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_t_stride) - .set_data_type(qkv_tensor_type)); - k_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K_t") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_t_stride) - .set_data_type(qkv_tensor_type)); - dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_t") - .set_dim({b, h, s_q, d_qk}) - .set_stride(dO_t_stride) - .set_data_type(do_tensor_type)); - dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_f16") - .set_dim({b, h, s_q, d_qk}) - .set_stride(dO_t_stride) - .set_data_type(o_tensor_type)); + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_v_scale_padded, dO_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_v_padded, dO_t_scale_strides.data(), layout, + NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -2364,14 +2386,14 @@ void fused_attn_fp8_bwd_impl_v1( .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO") - .set_dim({b, h, s_q_padded, d_qk_scale_padded}) - .set_stride(q_scale_strides) + .set_dim({b, h, s_q_padded, d_v_scale_padded}) + .set_stride(dO_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO_t") - .set_dim({b, h, s_q_scale_padded, d_qk_padded}) - .set_stride(q_t_scale_strides) + .set_dim({b, h, s_q_scale_padded, d_v_padded}) + .set_stride(dO_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } @@ -2382,8 +2404,17 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); - // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + // fe::DiagonalAlignment_t const &diagonal_alignment = + // bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + // : fe::DiagonalAlignment_t::TOP_LEFT; + // sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + + // if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + // } + // if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + // } // sdpa_backward_options.set_alibi_mask(is_alibi); @@ -2434,14 +2465,15 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - // if (!is_mxfp8) { - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - // } else { - // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( - // q, q_t, k, k_t, v, o, dO_f16, dO, dO_t, stats, descale_q, descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, descale_dO_t, - // sdpa_backward_options); + // if (is_delayed_scaling || is_current_scaling) { + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( + Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + // } + // if (is_mxfp8) { + // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( + // Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, + // descale_v, descale_dO, descale_dO_t, sdpa_backward_options); // } dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); @@ -2464,30 +2496,19 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); } - // dO->set_data_type(do_tensor_type); - // dQ->set_data_type(dqkv_tensor_type); - // dK->set_data_type(dqkv_tensor_type); - // dV->set_data_type(dqkv_tensor_type); - - std::tuple, // q - // std::shared_ptr, // q_t - std::shared_ptr, // k - // std::shared_ptr, // k_t - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO - // std::shared_ptr, // dO_t - // std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q - // std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k - // std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO - // std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2503,11 +2524,11 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dV std::shared_ptr> // amax_dP key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); - auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); - // key_tensors_tuple = std::tuple_cat(key_tensors_tuple, mxfp8_tensors_tuple); + // auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + // : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2521,23 +2542,19 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple, mxfp8_tensors_tuple); + padding_tuple, dropout_tuple); + // padding_tuple, dropout_tuple, mxfp8_tensors_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto bprop_tuple = get_graph(sdpa_fp8_bprop_cache, descriptor); - // if (!is_mxfp8) { - // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, - // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); auto mha_graph = std::get<0>(bprop_tuple); - auto q = std::get<1>(bprop_tuple); - auto k = std::get<2>(bprop_tuple); - auto v = std::get<3>(bprop_tuple); - auto o = std::get<4>(bprop_tuple); - auto stats = std::get<5>(bprop_tuple); + auto Q = std::get<1>(bprop_tuple); + auto K = std::get<2>(bprop_tuple); + auto V = std::get<3>(bprop_tuple); + auto O = std::get<4>(bprop_tuple); + auto Stats = std::get<5>(bprop_tuple); auto dO = std::get<6>(bprop_tuple); auto attn_scale = std::get<7>(bprop_tuple); auto descale_q = std::get<8>(bprop_tuple); @@ -2565,19 +2582,14 @@ void fused_attn_fp8_bwd_impl_v1( auto seq_kv = std::get<30>(bprop_tuple); auto dropout_seed = std::get<31>(bprop_tuple); auto dropout_offset = std::get<32>(bprop_tuple); - // } else { // if (is_mxfp8) { - // auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, - // descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - // dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, q_t, k_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, - // dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); - auto q_t = std::get<33>(bprop_tuple); - auto k_t = std::get<34>(bprop_tuple); - auto dO_f16 = std::get<35>(bprop_tuple); - auto dO_t = std::get<36>(bprop_tuple); - auto descale_q_t = std::get<37>(bprop_tuple); - auto descale_k_t = std::get<38>(bprop_tuple); - auto descale_dO_t = std::get<39>(bprop_tuple); + // auto Q_t = std::get<33>(bprop_tuple); + // auto K_t = std::get<34>(bprop_tuple); + // auto dO_f16 = std::get<35>(bprop_tuple); + // auto dO_t = std::get<36>(bprop_tuple); + // auto descale_q_t = std::get<37>(bprop_tuple); + // auto descale_k_t = std::get<38>(bprop_tuple); + // auto descale_dO_t = std::get<39>(bprop_tuple); // } auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2594,11 +2606,11 @@ void fused_attn_fp8_bwd_impl_v1( // build variant pack std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrK}, - {v, devPtrV}, - {o, devPtrO}, - {stats, devPtrM}, + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {Stats, devPtrM}, {dO, devPtrdO}, {attn_scale, &scaling_factor}, {descale_q, devPtrDescaleQ}, @@ -2612,31 +2624,31 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dK, devPtrAmaxdK}, {amax_dV, devPtrAmaxdV}, }; - if (!is_mxfp8) { - variant_pack[descale_s] = devPtrDescaleS; - variant_pack[descale_dP] = devPtrDescaledP; - variant_pack[scale_s] = devPtrScaleS; - variant_pack[scale_dP] = devPtrScaledP; - variant_pack[amax_dP] = devPtrAmaxdP; - } else { - variant_pack[q_t] = devPtrQ_t; - variant_pack[k_t] = devPtrK_t; - variant_pack[dO_f16] = devPtrdO_f16; - variant_pack[dO_t] = devPtrdO_t; - variant_pack[descale_q_t] = devPtrDescaleQ_t; - variant_pack[descale_k_t] = devPtrDescaleK_t; - variant_pack[descale_dO] = devPtrDescaledO; - variant_pack[descale_dO_t] = devPtrDescaledO_t; + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + } + if (is_current_scaling && !is_O_in_F16) { + variant_pack[descale_o] = devPtrDescaleO; } - if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (is_current_scaling && !is_O_in_F16) { - variant_pack[descale_o] = devPtrDescaleO; - } + // if (is_mxfp8) { + // variant_pack[Q_t] = devPtrQ_t; + // variant_pack[K_t] = devPtrK_t; + // variant_pack[dO_f16] = devPtrdO_f16; + // variant_pack[dO_t] = devPtrdO_t; + // variant_pack[descale_q_t] = devPtrDescaleQ_t; + // variant_pack[descale_k_t] = devPtrDescaleK_t; + // variant_pack[descale_dO] = devPtrDescaledO; + // variant_pack[descale_dO_t] = devPtrDescaledO_t; + // } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -2682,7 +2694,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, const Tensor* input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, @@ -2769,7 +2781,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), @@ -2803,7 +2815,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2872,7 +2884,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 22800b2aa2..bfadc0e870 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -18,7 +18,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, const Tensor *input_K, + NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, @@ -28,7 +28,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index fdfc4abe82..ea3428855c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -24,12 +24,14 @@ using namespace transformer_engine; enum NVTE_QKV_Matrix { NVTE_Q_Matrix = 0, // queries - NVTE_K_Matrix = 1, // keys - NVTE_K_Matrix_Transpose = 2, // keys transposed - NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_Q_Matrix_Transpose = 1, // queries transposed + NVTE_K_Matrix = 2, // keys + NVTE_K_Matrix_Transpose = 3, // keys transposed + NVTE_V_Matrix = 4, // values + NVTE_V_Matrix_Transpose = 5, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output + NVTE_O_Matrix_Transpose = 7, // final output transposed }; void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c0dca1b330..ac2c067ca8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1255,6 +1255,14 @@ def forward( dP_quantizer, ) + print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") + print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") + print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") + print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") + print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") + print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") + print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") # out_: # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 From d655e7e4e464585f11c3f68341ec71148f497537 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:11:05 -0800 Subject: [PATCH 19/59] attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 140 +++++++----------- .../dot_product_attention/backends.py | 56 ++++--- .../attention/dot_product_attention/utils.py | 21 ++- .../pytorch/cpp_extensions/fused_attn.py | 2 +- 4 files changed, 103 insertions(+), 116 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index b4aebea25d..504c387d57 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1682,13 +1682,6 @@ void fused_attn_fp8_fwd_impl_v1( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); - printf(">>>>>> b: %d, h: %d, h_g: %d, s_q: %d, s_kv: %d, d_qk: %d, d_v: %d\n", b, h, hg, s_q, s_kv, d_qk, d_v); - printf(">>>>>> scaling_mode: %d\n", scaling_mode); - printf(">>>>>> is_mxfp8: %d\n", is_mxfp8); - printf(">>>>>> is_current_scaling: %d\n", is_current_scaling); - printf(">>>>>> is_delayed_scaling: %d\n", is_delayed_scaling); - printf(">>>>>> qkv_tensor_type: %d, %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E8M0, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf(">>>>>> o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -1828,12 +1821,12 @@ void fused_attn_fp8_fwd_impl_v1( int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; + // int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; int64_t d_v_padded = ((d_v + 127) / 128) * 128; int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; + // int64_t d_v_scale = (d_v + block_size - 1) / block_size; int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + // int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); @@ -1843,10 +1836,6 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf(">>>>>> q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf(">>>>>> k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf(">>>>>> v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -1945,10 +1934,6 @@ void fused_attn_fp8_fwd_impl_v1( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - printf(">>>>>> q_stride: %d, %d, %d, %d\n", q_stride[0], q_stride[1], q_stride[2], q_stride[3]); - printf(">>>>>> k_stride: %d, %d, %d, %d\n", k_stride[0], k_stride[1], k_stride[2], k_stride[3]); - printf(">>>>>> v_stride: %d, %d, %d, %d\n", v_stride[0], v_stride[1], v_stride[2], v_stride[3]); - printf(">>>>>> o_stride: %d, %d, %d, %d\n", o_stride[0], o_stride[1], o_stride[2], o_stride[3]); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2141,24 +2126,24 @@ void fused_attn_fp8_bwd_impl_v1( using graph_and_tensors = std::tuple, std::shared_ptr, // Q - // std::shared_ptr, // Q_t + std::shared_ptr, // Q_t std::shared_ptr, // K - // std::shared_ptr, // K_t + std::shared_ptr, // K_t std::shared_ptr, // V std::shared_ptr, // O std::shared_ptr, // Stats std::shared_ptr, // dO - // std::shared_ptr, // dO_t - // std::shared_ptr, // dO_f16 + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q - // std::shared_ptr, // descale_q_t + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k - // std::shared_ptr, // descale_k_t + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO - // std::shared_ptr, // descale_dO_t + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2465,16 +2450,30 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - // if (is_delayed_scaling || is_current_scaling) { - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; + if (is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8_backward( Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); - // } - // if (is_mxfp8) { - // auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV] = mha_graph->sdpa_fp8_backward( - // Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, - // descale_v, descale_dO, descale_dO_t, sdpa_backward_options); - // } + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + amax_dP = outputs[6]; + } + if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8_backward( + Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, + descale_v, descale_dO, descale_dO_t, sdpa_backward_options); + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + } dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride).set_data_type(dqkv_tensor_type); @@ -2527,8 +2526,8 @@ void fused_attn_fp8_bwd_impl_v1( Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); - // auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) - // : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2541,56 +2540,17 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, bias_tuple, padding_tuple, dropout_tuple); - // padding_tuple, dropout_tuple, mxfp8_tensors_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto bprop_tuple = get_graph(sdpa_fp8_bprop_cache, descriptor); - auto mha_graph = std::get<0>(bprop_tuple); - auto Q = std::get<1>(bprop_tuple); - auto K = std::get<2>(bprop_tuple); - auto V = std::get<3>(bprop_tuple); - auto O = std::get<4>(bprop_tuple); - auto Stats = std::get<5>(bprop_tuple); - auto dO = std::get<6>(bprop_tuple); - auto attn_scale = std::get<7>(bprop_tuple); - auto descale_q = std::get<8>(bprop_tuple); - auto descale_k = std::get<9>(bprop_tuple); - auto descale_v = std::get<10>(bprop_tuple); - auto descale_o = std::get<11>(bprop_tuple); - auto descale_dO = std::get<12>(bprop_tuple); - auto descale_s = std::get<13>(bprop_tuple); - auto descale_dP = std::get<14>(bprop_tuple); - auto scale_s = std::get<15>(bprop_tuple); - auto scale_dQ = std::get<16>(bprop_tuple); - auto scale_dK = std::get<17>(bprop_tuple); - auto scale_dV = std::get<18>(bprop_tuple); - auto scale_dP = std::get<19>(bprop_tuple); - auto dQ = std::get<20>(bprop_tuple); - auto dK = std::get<21>(bprop_tuple); - auto dV = std::get<22>(bprop_tuple); - auto amax_dQ = std::get<23>(bprop_tuple); - auto amax_dK = std::get<24>(bprop_tuple); - auto amax_dV = std::get<25>(bprop_tuple); - auto amax_dP = std::get<26>(bprop_tuple); - auto bias = std::get<27>(bprop_tuple); - auto dBias = std::get<28>(bprop_tuple); - auto seq_q = std::get<29>(bprop_tuple); - auto seq_kv = std::get<30>(bprop_tuple); - auto dropout_seed = std::get<31>(bprop_tuple); - auto dropout_offset = std::get<32>(bprop_tuple); - // if (is_mxfp8) { - // auto Q_t = std::get<33>(bprop_tuple); - // auto K_t = std::get<34>(bprop_tuple); - // auto dO_f16 = std::get<35>(bprop_tuple); - // auto dO_t = std::get<36>(bprop_tuple); - // auto descale_q_t = std::get<37>(bprop_tuple); - // auto descale_k_t = std::get<38>(bprop_tuple); - // auto descale_dO_t = std::get<39>(bprop_tuple); - // } + auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, + bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + auto plan_workspace_size = mha_graph->get_workspace_size(); // Exit to request upper level API to allocate memory if needed @@ -2639,16 +2599,16 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - // if (is_mxfp8) { - // variant_pack[Q_t] = devPtrQ_t; - // variant_pack[K_t] = devPtrK_t; - // variant_pack[dO_f16] = devPtrdO_f16; - // variant_pack[dO_t] = devPtrdO_t; - // variant_pack[descale_q_t] = devPtrDescaleQ_t; - // variant_pack[descale_k_t] = devPtrDescaleK_t; - // variant_pack[descale_dO] = devPtrDescaledO; - // variant_pack[descale_dO_t] = devPtrDescaledO_t; - // } + if (is_mxfp8) { + variant_pack[Q_t] = devPtrQ_t; + variant_pack[K_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO] = devPtrDescaledO; + variant_pack[descale_dO_t] = devPtrDescaledO_t; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ac2c067ca8..98c48d0e3a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -203,8 +203,11 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou def backward(ctx, grad1, grad2, grad3): # pylint: disable=missing-function-docstring if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: - dt_fp8 = ctx.quantizer(grad1) - tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + if ctx.quantizer is not None: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + else: + tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( @@ -213,6 +216,10 @@ def backward(ctx, grad1, grad2, grad3): tensors = combine_and_dequantize( ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + assert ctx.qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -1341,7 +1348,7 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance(QKV_quantizer, MXFP8Quantizer): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) else: @@ -1481,13 +1488,30 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - + print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") + print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + if ctx.original_qkv_layout != ctx.qkv_layout: + original_qkv_format = ctx.original_qkv_layout.split("_")[0] + new_qkv_format = ctx.qkv_layout.split("_")[0] + perm = [] + for i in original_qkv_format: + perm.append(new_qkv_format.find(i)) + d_out = d_out.permute(*perm).contiguous() + print(f"d_out: {d_out.shape}, {type(d_out)}") # d_out is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): + print(f"before dO_quantizer: {type(d_out)}, {d_out.shape}") + d_out_f16 = d_out + ctx.dO_quantizer.optimize_for_gemm = True d_out = ctx.dO_quantizer(d_out) + print(f"after dO_quantizer: {type(d_out)}, {d_out.shape}") if not ctx.use_FAv2_bwd: - d_out._data = d_out._data.contiguous() + if isinstance(ctx.dO_quantizer, MXFP8Quantizer): + d_out._rowwise_data = d_out._rowwise_data.contiguous() + d_out._columnwise_data = d_out._columnwise_data.contiguous() + else: + d_out._data = d_out._data.contiguous() elif not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1549,14 +1573,6 @@ def backward(ctx, d_out, *_args): # FP8 attention: torch.float16 or torch.bfloat16 dqkv_nominal_dtype = ctx.nominal_dtype - if ctx.original_qkv_layout != ctx.qkv_layout: - original_qkv_format = ctx.original_qkv_layout.split("_")[0] - new_qkv_format = ctx.qkv_layout.split("_")[0] - perm = [] - for i in original_qkv_format: - perm.append(new_qkv_format.find(i)) - d_out = d_out.permute(*perm).contiguous() - if ctx.fp8: # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1595,10 +1611,16 @@ def backward(ctx, d_out, *_args): out_ = out_fp8 if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: out_ = out - if ctx.fp8_recipe.mxfp8_block_scaling(): + if ctx.fp8_recipe.mxfp8(): out_ = out - aux_ctx_tensors.append(d_out) - + aux_ctx_tensors.append(d_out_f16) + print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") + print(f"shapes: {q_fp8.shape}, {k_fp8.shape}, {v_fp8.shape}, {out_.shape}, {d_out_fp8.shape}, {[x.shape for x in aux_ctx_tensors]}") + for i in [q_fp8, k_fp8, v_fp8, out_, d_out_fp8, *aux_ctx_tensors]: + if isinstance(i, MXFP8Tensor): + print(f"xxxx: {i._with_gemm_swizzled_scales}") + else: + print(f"xxxx: {i.shape}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1621,7 +1643,7 @@ def backward(ctx, d_out, *_args): ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, + ctx.original_qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 12a75131aa..a9ce96c8c9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2091,10 +2091,11 @@ def get_attention_quantizers(fp8, quantizers): if not fp8: return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + # QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=True) O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer.set_usage(rowwise=True, columnwise=False) + # O_quantizer.optimize_for_gemm = True if isinstance(QKV_quantizer, MXFP8Quantizer): QKV_quantizer.optimize_for_gemm = True # QKV_quantizer.internal = False @@ -2105,14 +2106,18 @@ def get_attention_quantizers(fp8, quantizers): S_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer.interal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=True) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True - dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) - dP_quantizer.interal = True + dO_quantizer.set_usage(rowwise=True, columnwise=True) + dO_quantizer.internal = False + # dO_quantizer.optimize_for_gemm = True + if isinstance(dO_quantizer, MXFP8Quantizer): + dP_quantizer = None + else: + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.interal = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 09953440e9..8f77a8a7fd 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -496,7 +496,7 @@ def fused_attn_bwd( dqkv_dtype is not None ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( - len(aux_ctx_tensors) == 3 + len(aux_ctx_tensors) >= 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." output_tensors = tex.fused_attn_bwd( From a4ab691cc4eda08582d083ad0169ef2682adde44 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Feb 2026 18:04:31 -0800 Subject: [PATCH 20/59] fix get_quantizers; attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/backends.py | 88 +++++++++---------- .../dot_product_attention/context_parallel.py | 4 +- .../attention/dot_product_attention/utils.py | 47 +++++----- 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 98c48d0e3a..4d1481de79 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -452,19 +452,15 @@ def forward( scale /= self.layer_number if fp8: + # get fp8 recipe for DPA + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) - # disable swizzle for MXFP8Quantizer - for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: - if isinstance(q, MXFP8Quantizer): - q.optimize_for_gemm = False - q.internal = False # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=S_quantizer.dtype, device="cuda" @@ -472,6 +468,11 @@ def forward( dP_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=dP_quantizer.dtype, device="cuda" ) + # disable swizzle for MXFP8Quantizer + for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: + if isinstance(q, MXFP8Quantizer): + q.optimize_for_gemm = False + q.internal = False # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 @@ -1229,7 +1230,7 @@ def forward( # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # get nominal data type for out @@ -1488,31 +1489,32 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring + print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") - print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling if ctx.original_qkv_layout != ctx.qkv_layout: + print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") original_qkv_format = ctx.original_qkv_layout.split("_")[0] new_qkv_format = ctx.qkv_layout.split("_")[0] perm = [] for i in original_qkv_format: perm.append(new_qkv_format.find(i)) d_out = d_out.permute(*perm).contiguous() - print(f"d_out: {d_out.shape}, {type(d_out)}") - # d_out is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): - print(f"before dO_quantizer: {type(d_out)}, {d_out.shape}") - d_out_f16 = d_out - ctx.dO_quantizer.optimize_for_gemm = True - d_out = ctx.dO_quantizer(d_out) - print(f"after dO_quantizer: {type(d_out)}, {d_out.shape}") - if not ctx.use_FAv2_bwd: - if isinstance(ctx.dO_quantizer, MXFP8Quantizer): - d_out._rowwise_data = d_out._rowwise_data.contiguous() - d_out._columnwise_data = d_out._columnwise_data.contiguous() - else: - d_out._data = d_out._data.contiguous() - elif not ctx.use_FAv2_bwd: + print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") + + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + d_out_fp8 = None + if ctx.fp8: + print(f"d_out before quantizer: {d_out.shape}, {type(d_out)}") + if isinstance(d_out, QuantizedTensorStorage): + d_out_fp8 = d_out + else: + d_out_fp8 = ctx.dO_quantizer(d_out) + print(f"d_out after quantizer: {d_out_fp8.shape}, {type(d_out_fp8)}") + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( q_fp8, @@ -1574,14 +1576,6 @@ def backward(ctx, d_out, *_args): dqkv_nominal_dtype = ctx.nominal_dtype if ctx.fp8: - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - if ctx.is_output_fp8: - d_out_fp8 = d_out - else: - d_out_fp8 = ctx.dO_quantizer(d_out) - # print quantizers print_quantizers( "FusedAttnFunc.backward >> before: ", @@ -1613,14 +1607,9 @@ def backward(ctx, d_out, *_args): out_ = out if ctx.fp8_recipe.mxfp8(): out_ = out - aux_ctx_tensors.append(d_out_f16) + aux_ctx_tensors.append(d_out) print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") print(f"shapes: {q_fp8.shape}, {k_fp8.shape}, {v_fp8.shape}, {out_.shape}, {d_out_fp8.shape}, {[x.shape for x in aux_ctx_tensors]}") - for i in [q_fp8, k_fp8, v_fp8, out_, d_out_fp8, *aux_ctx_tensors]: - if isinstance(i, MXFP8Tensor): - print(f"xxxx: {i._with_gemm_swizzled_scales}") - else: - print(f"xxxx: {i.shape}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1632,7 +1621,7 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, + dqkv_te_dtype, # could we remove this? aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1643,7 +1632,7 @@ def backward(ctx, d_out, *_args): ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.original_qkv_layout, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1652,10 +1641,17 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) + if ctx.original_qkv_layout != ctx.qkv_layout: + original_qkv_format = ctx.original_qkv_layout.split("_")[0] + new_qkv_format = ctx.qkv_layout.split("_")[0] + perm = [] + for i in new_qkv_format: + perm.append(original_qkv_format.find(i)) + dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_quantized_tensor = isinstance(dq_, QuantizedTensor) + is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( @@ -1665,7 +1661,7 @@ def backward(ctx, d_out, *_args): dv_, src_nominal_dtype=dq_.dtype, ) - if not is_float8tensor and ctx.is_input_fp8: + if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 dq, dk, dv, _ = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer @@ -1982,7 +1978,7 @@ def forward( " with FP8!" ) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 244f24111d..34af861604 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1356,7 +1356,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) q_f16 = None q_fp8, k_fp8, v_fp8 = (None, None, None) @@ -3394,7 +3394,7 @@ def forward( max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index a9ce96c8c9..26a2bda08b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2086,38 +2086,41 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers): +def get_attention_quantizers(fp8, fp8_recipe, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - # QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=True) + QKV_quantizer.internal = True + QKV_quantizer.set_usage(rowwise=True, columnwise=False) O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer.set_usage(rowwise=True, columnwise=False) - # O_quantizer.optimize_for_gemm = True - if isinstance(QKV_quantizer, MXFP8Quantizer): - QKV_quantizer.optimize_for_gemm = True - # QKV_quantizer.internal = False - S_quantizer = None - else: - S_quantizer = quantizers["scaling_fwd"][META_S] - S_quantizer.internal = True - S_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = False - dQKV_quantizer.set_usage(rowwise=True, columnwise=True) + dQKV_quantizer.interal = True + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=True) - dO_quantizer.internal = False - # dO_quantizer.optimize_for_gemm = True - if isinstance(dO_quantizer, MXFP8Quantizer): + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.interal = True + + if fp8_recipe.mxfp8(): + QKV_quantizer.columnwise = True + QKV_quantizer.optimize_for_gemm = True + O_quantizer.columnwise = True + O_quantizer.optimize_for_gemm = True + S_quantizer = None + dQKV_quantizer.columnwise = True + dQKV_quantizer.optimize_for_gemm = True + dO_quantizer.columnwise = True + dO_quantizer.optimize_for_gemm = True dP_quantizer = None - else: - dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) - dP_quantizer.interal = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer From a85070dbc38f4ce749d4d0f246a9fe29b928112d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Feb 2026 17:03:04 -0800 Subject: [PATCH 21/59] fix fprop; add o_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 54 +++++++++++++++++-- .../common/fused_attn/fused_attn_fp8.cu | 23 ++++---- .../common/fused_attn/fused_attn_fp8.h | 2 +- transformer_engine/common/fused_attn/utils.cu | 31 +++++++++++ transformer_engine/common/fused_attn/utils.h | 3 ++ .../include/transformer_engine/fused_attn.h | 13 ++++- .../dot_product_attention/backends.py | 43 +++++++-------- .../attention/dot_product_attention/utils.py | 30 ++++++----- .../pytorch/cpp_extensions/fused_attn.py | 4 ++ transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 20 ++++--- 11 files changed, 162 insertions(+), 63 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ebda62568a..79a8417bdc 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -208,6 +208,52 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } } +// map one NVTE_QKV_Format to another +std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format) { + std::vector dst_shape(src_shape.size()); + size_t b=0, h=0, s=0, d=0, t=0; + switch (src_format) { + case NVTE_QKV_Format::NVTE_BSHD: + b = src_shape[0]; + s = src_shape[1]; + h = src_shape[2]; + d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_SBHD: + s = src_shape[0]; + b = src_shape[1]; + h = src_shape[2]; + d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_BHSD: + b = src_shape[0]; + h = src_shape[1]; + s = src_shape[2]; + d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_THD: + t = src_shape[0]; + h = src_shape[1]; + d = src_shape[2]; + break; + } + switch (dst_format) { + case NVTE_QKV_Format::NVTE_BSHD: + dst_shape = {b, s, h, d}; + break; + case NVTE_QKV_Format::NVTE_SBHD: + dst_shape = {s, b, h, d}; + break; + case NVTE_QKV_Format::NVTE_BHSD: + dst_shape = {b, h, s, d}; + break; + case NVTE_QKV_Format::NVTE_THD: + dst_shape = {t, h, d}; + break; + } + return dst_shape; +} + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, @@ -631,7 +677,7 @@ void nvte_fused_attn_fwd_qkvpacked( Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, + qkv_layout, qkv_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -941,7 +987,7 @@ void nvte_fused_attn_fwd_kvpacked( Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, + dropout, qkv_layout, q_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -1125,7 +1171,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { @@ -1228,7 +1274,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + dropout, qkv_layout, o_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 504c387d57..4a9af8b4b3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1653,7 +1653,7 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, @@ -1702,7 +1702,7 @@ void fused_attn_fp8_fwd_impl_v1( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, @@ -1767,11 +1767,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -1830,11 +1830,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -1932,8 +1932,7 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2653,7 +2652,7 @@ void fused_attn_fp8_bwd_impl_v1( void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, @@ -2741,7 +2740,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index bfadc0e870..548a41a561 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -17,7 +17,7 @@ namespace transformer_engine { void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index c6d6386fb7..0ea5d6aa7f 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -17,6 +17,37 @@ namespace fused_attn { using namespace transformer_engine; +// get matrix strides based on matrix type +void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strideA, NVTE_QKV_Format format) { +constexpr int batch_dim_idx = 0; +constexpr int head_dim_idx = 1; +constexpr int seqlen_dim_idx = 2; +constexpr int hidden_dim_idx = 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strideA[batch_dim_idx] = s * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strideA[batch_dim_idx] = h * s * d; + strideA[head_dim_idx] = s * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + break; + } +} + // get matrix strides based on matrix type void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index ea3428855c..88535f61c9 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -34,6 +34,9 @@ enum NVTE_QKV_Matrix { NVTE_O_Matrix_Transpose = 7, // final output transposed }; +void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strideA, NVTE_QKV_Format format); + void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 204d8f3d5a..883c5a6e61 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -193,6 +193,16 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Convert one NVTE_QKV_Format to another. + * + * \param[in] src_shape The source shape. + * \param[in] src_format The source format. + * \param[in] dst_format The destination format. + * + * \return The destination shape. + */ + std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] is_training Whether the model is in training mode. @@ -563,6 +573,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -581,7 +592,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4d1481de79..2838eaa5eb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1213,21 +1213,21 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + original_qkv_layout = qkv_layout + _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) + # input types are inferred from the real data while output types are controlled by fp8_output # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) - # whether bwd kernel in FP8: + # whether fwd kernel will be run in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel will be run in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - # save original qkv_layout - original_qkv_layout = qkv_layout - # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) @@ -1249,7 +1249,7 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(original_qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) # print quantizers print_quantizers( @@ -1263,14 +1263,6 @@ def forward( dP_quantizer, ) - print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") - print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") - print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") - print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") - print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") - print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") - print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") - print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") # out_: # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1298,6 +1290,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1307,20 +1300,21 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - if original_qkv_layout != qkv_layout: - original_qkv_format = original_qkv_layout.split("_")[0] - new_qkv_format = qkv_layout.split("_")[0] - perm = [] - for i in new_qkv_format: - perm.append(original_qkv_format.find(i)) - out_ = out_.permute(*perm).contiguous() + print(f"out_.shape: {out_.shape}, type(out_): {type(out_)}") + # if original_qkv_layout != qkv_layout: + # original_qkv_format = original_qkv_layout.split("_")[0] + # new_qkv_format = qkv_layout.split("_")[0] + # perm = [] + # for i in new_qkv_format: + # perm.append(original_qkv_format.find(i)) + # out_ = out_.permute(*perm).contiguous() # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - if isinstance(out_, Float8Tensor) or isinstance(out_, MXFP8Tensor): + if isinstance(out_, QuantizedTensorStorage): if not is_output_fp8 or not is_bwd_fp8: out = out_.dequantize().view(out_.shape) else: @@ -1382,6 +1376,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 26a2bda08b..176434d883 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2094,33 +2094,37 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.internal = False + O_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer.internal = True + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + dP_quantizer.set_usage(rowwise=True, columnwise=False) + + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.interal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) if fp8_recipe.mxfp8(): - QKV_quantizer.columnwise = True + QKV_quantizer.columnwise_usage = True QKV_quantizer.optimize_for_gemm = True - O_quantizer.columnwise = True - O_quantizer.optimize_for_gemm = True S_quantizer = None - dQKV_quantizer.columnwise = True - dQKV_quantizer.optimize_for_gemm = True - dO_quantizer.columnwise = True + O_quantizer.columnwise_usage = True + + dO_quantizer.columnwise_usage = True dO_quantizer.optimize_for_gemm = True dP_quantizer = None + dQKV_quantizer.columnwise_usage = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 8f77a8a7fd..629046aa1c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -135,6 +135,7 @@ def fused_attn_fwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -204,6 +205,8 @@ def fused_attn_fwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -307,6 +310,7 @@ def fused_attn_fwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e9531573d6..7d2d002111 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -85,7 +85,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 0d7a842ce1..30415b4373 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -114,7 +114,7 @@ std::pair quantizer_helper(py::handle quantizer, // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -150,8 +150,9 @@ std::vector fused_attn_fwd( std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; + o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape = nvte_convert_qkv_format(o_shape_tmp, nvte_get_q_format(qkv_layout), o_format); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -164,7 +165,7 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -172,7 +173,7 @@ std::vector fused_attn_fwd( } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (o_format == NVTE_QKV_Format::NVTE_THD) { te_O.zero_(at::cuda::getCurrentCUDAStream()); } } else { @@ -251,7 +252,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -311,7 +312,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -367,6 +368,11 @@ std::vector fused_attn_bwd( at::Tensor dQ, dK, dV, dQKV, dKV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + // DType dqkv_type = DType::kNumTypes; + // if (!dqkv_quantizer.is_none()) { + // dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); + // } + // printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); From 8909b35da8ff35bd09bbec184afceaa749f068a6 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Feb 2026 19:00:07 -0800 Subject: [PATCH 22/59] attempt at bwd with o_format/d_out_format/dqkv_layout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 8 +-- .../common/fused_attn/fused_attn_fp8.cu | 61 +++++++++-------- .../common/fused_attn/fused_attn_fp8.h | 2 +- .../include/transformer_engine/fused_attn.h | 5 +- .../dot_product_attention/backends.py | 68 ++++++++++--------- .../attention/dot_product_attention/utils.py | 25 +++---- .../pytorch/cpp_extensions/fused_attn.py | 29 ++++++-- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/attention.cpp | 26 +++---- 9 files changed, 127 insertions(+), 101 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 79a8417bdc..1e9673fff7 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -828,7 +828,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, qkv_format, qkv_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, @@ -1150,7 +1150,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, + qkv_layout, q_format, q_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1293,7 +1293,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, @@ -1390,7 +1390,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 4a9af8b4b3..2afe979f04 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2050,7 +2050,7 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, + float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, @@ -2107,7 +2107,7 @@ void fused_attn_fp8_bwd_impl_v1( scaling_factor, true, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, @@ -2197,14 +2197,13 @@ void fused_attn_fp8_bwd_impl_v1( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -2277,12 +2276,11 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, h, s_q, s_kv, d_v, dO_t_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); + generateMatrixStridesWithFormat(b, h, d_v, s_q, dO_t_stride.data(), d_out_format); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2324,20 +2322,18 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_scale_strides(4); std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), layout, + generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), layout, + generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_v_scale_padded, dO_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_v_padded, dO_t_scale_strides.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix_Transpose); + generateMatrixStridesWithFormat(b, h, s_q_padded, d_v_scale_padded, dO_scale_strides.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, d_v_padded, s_q_scale_padded, dO_t_scale_strides.data(), d_out_format); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -2473,9 +2469,18 @@ void fused_attn_fp8_bwd_impl_v1( amax_dK = outputs[4]; amax_dV = outputs[5]; } - dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride).set_data_type(dqkv_tensor_type); - dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride).set_data_type(dqkv_tensor_type); - dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride).set_data_type(dqkv_tensor_type); + std::vector dq_stride(4); + std::vector dk_stride(4); + std::vector dv_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d_qk, dq_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, dk_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_v, dv_stride.data(), dqkv_layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(dq_stride).set_data_type(dqkv_tensor_type); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(dk_stride).set_data_type(dqkv_tensor_type); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(dv_stride).set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2773,7 +2778,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, @@ -2839,11 +2844,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, @@ -2853,7 +2858,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 548a41a561..215b5dd92a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -27,7 +27,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 883c5a6e61..d866cab702 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -647,6 +647,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. + * \param[in] d_out_format Output gradient's format. + * \param[in] dqkv_layout QKV gradient tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -666,7 +669,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2838eaa5eb..1a4b34f4fa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1213,6 +1213,7 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + # save qkv_layout and get output format original_qkv_layout = qkv_layout _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) @@ -1300,14 +1301,6 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - print(f"out_.shape: {out_.shape}, type(out_): {type(out_)}") - # if original_qkv_layout != qkv_layout: - # original_qkv_format = original_qkv_layout.split("_")[0] - # new_qkv_format = qkv_layout.split("_")[0] - # perm = [] - # for i in new_qkv_format: - # perm.append(original_qkv_format.find(i)) - # out_ = out_.permute(*perm).contiguous() # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1463,9 +1456,10 @@ def forward( else: ctx.qkv_layout = qkv_layout else: - ctx.original_qkv_layout = original_qkv_layout ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type @@ -1486,29 +1480,32 @@ def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") - # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling - if ctx.original_qkv_layout != ctx.qkv_layout: - print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") - print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") - original_qkv_format = ctx.original_qkv_layout.split("_")[0] - new_qkv_format = ctx.qkv_layout.split("_")[0] - perm = [] - for i in original_qkv_format: - perm.append(new_qkv_format.find(i)) - d_out = d_out.permute(*perm).contiguous() - print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") + # # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling + # if ctx.original_qkv_layout != ctx.qkv_layout: + # print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") + # print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") + # original_qkv_format = ctx.original_qkv_layout.split("_")[0] + # new_qkv_format = ctx.qkv_layout.split("_")[0] + # perm = [] + # for i in original_qkv_format: + # perm.append(new_qkv_format.find(i)) + # d_out = d_out.permute(*perm).contiguous() + # print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 d_out_fp8 = None + d_out_format = ctx.o_format if ctx.fp8: print(f"d_out before quantizer: {d_out.shape}, {type(d_out)}") + if ctx.fp8_recipe.mxfp8(): + d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - print(f"d_out after quantizer: {d_out_fp8.shape}, {type(d_out_fp8)}") + print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8.shape}, {type(d_out)}, {type(d_out_fp8)}") if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1583,8 +1580,8 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # get tex.DType for dq, dk, dv data - dqkv_te_dtype = d_out_fp8._fp8_dtype + # # get tex.DType for dq, dk, dv data + # dqkv_te_dtype = d_out_fp8._fp8_dtype # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 @@ -1616,7 +1613,7 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, # could we remove this? + # dqkv_te_dtype, # could we remove this? aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1628,6 +1625,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1636,13 +1636,15 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) - if ctx.original_qkv_layout != ctx.qkv_layout: - original_qkv_format = ctx.original_qkv_layout.split("_")[0] - new_qkv_format = ctx.qkv_layout.split("_")[0] - perm = [] - for i in new_qkv_format: - perm.append(original_qkv_format.find(i)) - dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] + print(f"dq_.shape: {dq_.shape}, dk_.shape: {dk_.shape}, dv_.shape: {dv_.shape}") + print(f"types: {type(dq_)}, {type(dk_)}, {type(dv_)}") + # if ctx.original_qkv_layout != ctx.qkv_layout: + # original_qkv_format = ctx.original_qkv_layout.split("_")[0] + # new_qkv_format = ctx.qkv_layout.split("_")[0] + # perm = [] + # for i in new_qkv_format: + # perm.append(original_qkv_format.find(i)) + # dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ @@ -1676,7 +1678,7 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - dqkv_te_dtype = TE_DType[d_out.dtype] + # dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1689,7 +1691,7 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 176434d883..379e056b54 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2190,6 +2190,16 @@ def print_quantizers( f"{label} >> {names[i]:14s}: {type_str}" ) +def permute_to_grouped_tensor(src_format, tensor): + """Permute tensor to bhsd or htd format for grouped quantization in MXFP8BlockScaling. src_format ={bshd, sbhd, thd}""" + if src_format in ["bhsd", "htd"]: + return tensor, src_format + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + dim_s_or_t = src_format.find("s") if 's' in src_format else src_format.find("t") + dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] + perm = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] + tensor = tensor.permute(*perm).contiguous() + return tensor, "bhsd" if src_format != "thd" else "htd" def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" @@ -2199,22 +2209,13 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype if isinstance(qkv_quantizer, MXFP8Quantizer): - - def permute_x(f, x): - x = x.contiguous() if not x.is_contiguous() else x - dim_s_dim_t = f.find("s") if 's' in f else f.find("t") - dim_others = [i for i in range(len(x.shape)) if i != dim_s_dim_t] - perm = [*dim_others[:-1], dim_s_dim_t, dim_others[-1]] - x = x.permute(*perm).contiguous() - return x - # bs3hd, sb3hd, etc -> bshd_bshd_bhsd -> bhsd_bhsd_bhsd # t3hd, etc -> thd_thd_thd -> htd_htd_htd if q_format not in ["bhsd", "htd"]: - q = permute_x(q_format, q) + q, _ = permute_to_grouped_tensor(q_format, q) if kv_format not in ["bhsd", "htd"]: - k = permute_x(kv_format, k) - v = permute_x(kv_format, v) + k, _ = permute_to_grouped_tensor(kv_format, k) + v, _ = permute_to_grouped_tensor(kv_format, v) qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" original_shapes = [x.shape for x in [q, k, v]] diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 629046aa1c..7a756ead1c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -364,7 +364,7 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - dqkv_dtype: tex.DType, + # dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -376,6 +376,9 @@ def fused_attn_bwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + d_out_format: str = "sbhd", + dqkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -414,8 +417,8 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - dqkv_dtype : tex.DType - data type of dQ, dK and dV; in tex.DType, not torch.dtype + # dqkv_dtype : tex.DType + # data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] @@ -442,6 +445,15 @@ def fused_attn_bwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + d_out_format : str, default = "sbhd" + format of dO; {"sbhd", "bshd", "thd"} + dqkv_layout : str, default = "sbh3d" + layout of dQ, dK and dV; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -496,9 +508,9 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - dqkv_dtype is not None - ), "dqkv_dtype is required as an input for FP8 fused attention backward." + # assert ( + # dqkv_dtype is not None + # ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( len(aux_ctx_tensors) >= 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." @@ -510,6 +522,9 @@ def fused_attn_bwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[d_out_format], + QKVLayout[dqkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -524,7 +539,7 @@ def fused_attn_bwd( o, d_o, fake_dtype, - dqkv_dtype, + # dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7d2d002111..795f50f672 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -99,11 +99,11 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 30415b4373..0cb0ae0a06 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -327,11 +327,11 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -366,18 +366,18 @@ std::vector fused_attn_bwd( const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); std::vector tmp_shape; - // DType dqkv_type = DType::kNumTypes; - // if (!dqkv_quantizer.is_none()) { - // dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); - // } - // printf(">>>>>> dQKV_type: %d\n", dqkv_type); + DType dqkv_type = fake_dtype_te; + if (!dqkv_quantizer.is_none()) { + dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); + } + printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } @@ -460,7 +460,7 @@ std::vector fused_attn_bwd( // construct NVTE tensors if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD)) { if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -473,7 +473,7 @@ std::vector fused_attn_bwd( } } } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); @@ -560,7 +560,7 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -577,7 +577,7 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); From 90a636c9e132b4288644e7c1cc94cdc9d3c673dc Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:08:51 -0800 Subject: [PATCH 23/59] fix dtype/o_format/etc in bwd calls Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 4 ++-- transformer_engine/common/fused_attn/fused_attn.cpp | 7 +++++-- .../attention/dot_product_attention/backends.py | 13 +++++++++++-- .../pytorch/csrc/extensions/attention.cpp | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index ff3c7506e9..7f353a9483 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1788,8 +1788,8 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), + "fp8_9": ModelConfig(2, 2048, 16, 128),#, attn_mask_type="causal"), + "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128), #, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1e9673fff7..ed0971627e 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -447,12 +447,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -1345,6 +1346,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, cuda_graph, deterministic); + printf("Q_type: %d, KV_type: %d, qkv_layout: %d, bias_type: %d, attn_mask_type: %d, softmax_type: %d, dropout: %f, h_q: %d, h_kv: %d, max_seqlen_q: %d, max_seqlen_kv: %d, d_qk: %d, d_v: %d, window_size_left: %d, window_size_right: %d, deterministic: %d\n", Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, deterministic); + printf("fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 1a4b34f4fa..d49e7f2365 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1457,6 +1457,8 @@ def forward( ctx.qkv_layout = qkv_layout else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout ctx.o_format = o_format ctx.dqkv_layout = original_qkv_layout @@ -1505,7 +1507,7 @@ def backward(ctx, d_out, *_args): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8.shape}, {type(d_out)}, {type(d_out_fp8)}") + print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8._rowwise_data.shape}, {type(d_out)}, {type(d_out_fp8)}") if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1600,8 +1602,12 @@ def backward(ctx, d_out, *_args): if ctx.fp8_recipe.mxfp8(): out_ = out aux_ctx_tensors.append(d_out) + print(f"q_fp8._with_gemm_swizzled_scales: {q_fp8._with_gemm_swizzled_scales}") + print(f"k_fp8._with_gemm_swizzled_scales: {k_fp8._with_gemm_swizzled_scales}") + print(f"v_fp8._with_gemm_swizzled_scales: {v_fp8._with_gemm_swizzled_scales}") + print(f"d_out_fp8._with_gemm_swizzled_scales: {d_out_fp8._with_gemm_swizzled_scales}") print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") - print(f"shapes: {q_fp8.shape}, {k_fp8.shape}, {v_fp8.shape}, {out_.shape}, {d_out_fp8.shape}, {[x.shape for x in aux_ctx_tensors]}") + print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1703,6 +1709,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 0cb0ae0a06..fc870c4591 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -370,7 +370,7 @@ std::vector fused_attn_bwd( std::vector tmp_shape; DType dqkv_type = fake_dtype_te; if (!dqkv_quantizer.is_none()) { - dqkv_type = dqkv_quantizer.attr("fp8_dtype").cast(); + dqkv_type = dqkv_quantizer.attr("dtype").cast(); } printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); From 8c72deaa83ee8a2816fa4059c7f569e998e9e29e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:09:34 -0800 Subject: [PATCH 24/59] fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 138 ++++---- transformer_engine/common/fused_attn/utils.cu | 36 +- transformer_engine/common/fused_attn/utils.h | 310 +++++++++++++++++- 3 files changed, 390 insertions(+), 94 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 2afe979f04..9de1fdeabc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1767,11 +1767,11 @@ void fused_attn_fp8_fwd_impl_v1( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -1814,28 +1814,15 @@ void fused_attn_fp8_fwd_impl_v1( } } if (is_mxfp8) { - int32_t block_size = 32; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; - int64_t s_q_scale = (s_q + block_size - 1) / block_size; - int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; - int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - // int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; - int64_t d_v_padded = ((d_v + 127) / 128) * 128; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - // int64_t d_v_scale = (d_v + block_size - 1) / block_size; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - // int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); std::vector q_scale_strides(4); std::vector k_scale_strides(4); std::vector v_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_v_padded, v_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, true); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -1932,7 +1919,7 @@ void fused_attn_fp8_fwd_impl_v1( } std::vector o_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) @@ -2197,13 +2184,13 @@ void fused_attn_fp8_bwd_impl_v1( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -2272,15 +2259,41 @@ void fused_attn_fp8_bwd_impl_v1( } } if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); // Q_t, K_t, dO_t, dO_f16 std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_t_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_t_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStridesWithFormat(b, h, d_v, s_q, dO_t_stride.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, true); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, true); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, true); + printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], q_t_stride[3]); + printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], k_t_stride[3]); + printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], dO_t_stride[3]); + printf("qkv_tensor_type: %d\n", qkv_tensor_type); + printf("o_tensor_type: %d\n", o_tensor_type); + printf("do_tensor_type: %d\n", do_tensor_type); + printf("dqkv_tensor_type: %d\n", dqkv_tensor_type); + printf("qkv_layout: %d\n", qkv_layout); + printf("o_format: %d\n", o_format); + printf("d_out_format: %d\n", d_out_format); + printf("dqkv_layout: %d\n", dqkv_layout); + printf("b: %d\n", b); + printf("h: %d\n", h); + printf("hg: %d\n", hg); + printf("s_q: %d\n", s_q); + printf("s_kv: %d\n", s_kv); + printf("d_qk: %d\n", d_qk); + printf("d_v: %d\n", d_v); + printf("is_delayed_scaling: %d\n", is_delayed_scaling); + printf("is_current_scaling: %d\n", is_current_scaling); + printf("is_O_in_F16: %d\n", is_O_in_F16); + printf("is_mxfp8: %d\n", is_mxfp8); + printf("is_causal: %d\n", is_causal); + printf("is_padding: %d\n", is_padding); + printf("is_dropout: %d\n", is_dropout); + printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2302,19 +2315,19 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride(o_stride) .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t - int32_t block_size = 32; - int64_t s_q_padded = ((s_q + 127) / 128) * 128; - int64_t s_kv_padded = ((s_kv + 127) / 128) * 128; - int64_t s_q_scale = (s_q + block_size - 1) / block_size; - int64_t s_kv_scale = (s_kv + block_size - 1) / block_size; - int64_t s_q_scale_padded = ((s_q_scale + 3) / 4) * 4; - int64_t s_kv_scale_padded = ((s_kv_scale + 3) / 4) * 4; - int64_t d_qk_padded = ((d_qk + 127) / 128) * 128; - int64_t d_v_padded = ((d_v + 127) / 128) * 128; - int64_t d_qk_scale = (d_qk + block_size - 1) / block_size; - int64_t d_v_scale = (d_v + block_size - 1) / block_size; - int64_t d_qk_scale_padded = ((d_qk_scale + 3) / 4) * 4; - int64_t d_v_scale_padded = ((d_v_scale + 3) / 4) * 4; + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + printf("s_q_padded: %d\n", padded.s_q_padded); + printf("s_kv_padded: %d\n", padded.s_kv_padded); + printf("s_q_scale: %d\n", padded.s_q_scale); + printf("s_kv_scale: %d\n", padded.s_kv_scale); + printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); + printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); + printf("d_qk_padded: %d\n", padded.d_qk_padded); + printf("d_v_padded: %d\n", padded.d_v_padded); + printf("d_qk_scale: %d\n", padded.d_qk_scale); + printf("d_v_scale: %d\n", padded.d_v_scale); + printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); + printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2322,18 +2335,20 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_scale_strides(4); std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); - generateMatrixStrides(b, h, s_q_padded, s_kv_padded, d_qk_scale_padded, q_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, h, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, q_t_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_qk_scale_padded, k_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q_scale_padded, s_kv_scale_padded, d_qk_padded, k_t_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - generateMatrixStrides(b, hg, s_q_padded, s_kv_padded, d_v_scale_padded, v_scale_strides.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStridesWithFormat(b, h, s_q_padded, d_v_scale_padded, dO_scale_strides.data(), d_out_format); - generateMatrixStridesWithFormat(b, h, d_v_padded, s_q_scale_padded, dO_t_scale_strides.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, dO_scale_strides.data(), d_out_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, true); + printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); + printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); + printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); + printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); + printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); + printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], dO_scale_strides[2], dO_scale_strides[3]); + printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, s_q_padded, d_qk_scale_padded}) @@ -2472,12 +2487,15 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dq_stride(4); std::vector dk_stride(4); std::vector dv_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, dq_stride.data(), dqkv_layout, + generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, dq_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, dk_stride.data(), dqkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dk_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, dv_stride.data(), dqkv_layout, + generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); + printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); + printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(dq_stride).set_data_type(dqkv_tensor_type); dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(dk_stride).set_data_type(dqkv_tensor_type); dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(dv_stride).set_data_type(dqkv_tensor_type); diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 0ea5d6aa7f..8a9399e830 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -17,37 +17,6 @@ namespace fused_attn { using namespace transformer_engine; -// get matrix strides based on matrix type -void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strideA, NVTE_QKV_Format format) { -constexpr int batch_dim_idx = 0; -constexpr int head_dim_idx = 1; -constexpr int seqlen_dim_idx = 2; -constexpr int hidden_dim_idx = 3; - - switch (format) { - case NVTE_QKV_Format::NVTE_BSHD: - case NVTE_QKV_Format::NVTE_THD: - strideA[batch_dim_idx] = s * h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = h * d; - strideA[hidden_dim_idx] = 1; - break; - case NVTE_QKV_Format::NVTE_SBHD: - strideA[batch_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[seqlen_dim_idx] = b * h * d; - strideA[hidden_dim_idx] = 1; - break; - case NVTE_QKV_Format::NVTE_BHSD: - strideA[batch_dim_idx] = h * s * d; - strideA[head_dim_idx] = s * d; - strideA[seqlen_dim_idx] = d; - strideA[hidden_dim_idx] = 1; - break; - } -} - // get matrix strides based on matrix type void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { @@ -343,6 +312,11 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[head_dim_idx] = s_kv * d; strideA[seqlen_transpose_dim_idx] = d; strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; } break; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 88535f61c9..f0b947c379 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -31,11 +31,315 @@ enum NVTE_QKV_Matrix { NVTE_V_Matrix_Transpose = 5, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output - NVTE_O_Matrix_Transpose = 7, // final output transposed }; -void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strideA, NVTE_QKV_Format format); +// Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) +struct MXFP8PaddedSizes { + int64_t s_q_padded; + int64_t s_kv_padded; + int64_t s_q_scale; + int64_t s_kv_scale; + int64_t s_q_scale_padded; + int64_t s_kv_scale_padded; + int64_t d_qk_padded; + int64_t d_v_padded; + int64_t d_qk_scale; + int64_t d_v_scale; + int64_t d_qk_scale_padded; + int64_t d_v_scale_padded; +}; + +// Pad s and d for MXFP8 layout +inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { + constexpr int64_t block_size = 32; + MXFP8PaddedSizes p; + p.s_q_padded = ((s_q + 127) / 128) * 128; + p.s_kv_padded = ((s_kv + 127) / 128) * 128; + p.s_q_scale = (s_q + block_size - 1) / block_size; + p.s_kv_scale = (s_kv + block_size - 1) / block_size; + p.s_q_scale_padded = ((p.s_q_scale + 3) / 4) * 4; + p.s_kv_scale_padded = ((p.s_kv_scale + 3) / 4) * 4; + p.d_qk_padded = ((d_qk + 127) / 128) * 128; + p.d_v_padded = ((d_v + 127) / 128) * 128; + p.d_qk_scale = (d_qk + block_size - 1) / block_size; + p.d_v_scale = (d_v + block_size - 1) / block_size; + p.d_qk_scale_padded = ((p.d_qk_scale + 3) / 4) * 4; + p.d_v_scale_padded = ((p.d_v_scale + 3) / 4) * 4; + return p; +} + +// Get matrix strides for a 4D tensor [batch, head, seqlen, hidden] given a QKV format. +// strideA must point to at least 4 int64_t elements. +inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strides, NVTE_QKV_Format format, bool transpose) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 2 : 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strides[batch_dim_idx] = s * h * d; + strides[head_dim_idx] = d; + strides[seqlen_dim_idx] = h * d; + strides[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strides[batch_dim_idx] = h * d; + strides[head_dim_idx] = d; + strides[seqlen_dim_idx] = b * h * d; + strides[hidden_dim_idx] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strides[batch_dim_idx] = h * s * d; + strides[head_dim_idx] = s * d; + strides[seqlen_dim_idx] = d; + strides[hidden_dim_idx] = 1; + break; + default: + NVTE_CHECK(false, "Invalid format."); + break; + } +} + +// get matrix strides based on matrix type +inline void generateMatrixStrides_v1( + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t *strides, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) +{ + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + bool transpose = + (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 2 : 3; + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; + } + NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * s_q * d_qk; + strides[head_dim_idx] = s_q * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_qk; + strides[head_dim_idx] = s_kv * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_v; + strides[head_dim_idx] = s_kv * d_v; + strides[seqlen_dim_idx] = d_v; + strides[hidden_dim_idx] = 1; + } + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } + + if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { + strides[seqlen_kv_dim_idx] = 1; + strides[seqlen_q_dim_idx] = s_kv; + strides[head_dim_idx] = s_q * s_kv; + strides[batch_dim_idx] = h * s_q * s_kv; + } +} void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); From 5f23eddf1e7fdbd7c26526d45155a877012f0841 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:16:01 -0800 Subject: [PATCH 25/59] fix upon last commit for paddedsizes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9de1fdeabc..f13eef3a66 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1825,19 +1825,19 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, true); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") - .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") - .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") - .set_dim({b, hg, s_kv_scale_padded, d_v_padded}) + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); @@ -2351,43 +2351,43 @@ void fused_attn_fp8_bwd_impl_v1( printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") - .set_dim({b, h, s_q_padded, d_qk_scale_padded}) + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q_t") - .set_dim({b, h, s_q_scale_padded, d_qk_padded}) + .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) .set_stride(q_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") - .set_dim({b, hg, s_kv_padded, d_qk_scale_padded}) + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_k_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k_t") - .set_dim({b, hg, s_kv_scale_padded, d_qk_padded}) + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) .set_stride(k_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") - .set_dim({b, hg, s_kv_padded, d_v_scale_padded}) + .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO") - .set_dim({b, h, s_q_padded, d_v_scale_padded}) + .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) .set_stride(dO_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO_t") - .set_dim({b, h, s_q_scale_padded, d_v_padded}) + .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) .set_stride(dO_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); From 18c55801b592d620a6a8c4ea02828f06b0f8d3fd Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:29:30 -0800 Subject: [PATCH 26/59] add mxfp8 env var Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/dot_product_attention.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 55553d30be..f7699340e6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -139,6 +139,11 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | +| | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ """ _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} @@ -673,6 +678,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs From 68476456981a2533b0e18f20340403ee0f50f08d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:30:03 -0800 Subject: [PATCH 27/59] disable FA for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 379e056b54..e8a4170cf3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -475,6 +475,9 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -482,6 +485,10 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False + if use_flash_attention_3 and fp8_recipe.mxfp8(): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for MXFP8") + use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() @@ -489,9 +496,6 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") @@ -603,9 +607,9 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if use_flash_attention_2 and FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 as it does not support MLA.") - use_flash_attention_2 = False + # if use_flash_attention_2 and FlashAttentionUtils.is_installed: + # logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + # use_flash_attention_2 = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": From c5a98d5e9dcbba2f84a9328236b5a1a47616d97d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 20 Feb 2026 19:21:29 -0800 Subject: [PATCH 28/59] add mha test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 7f353a9483..f0e70280bf 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1816,7 +1816,7 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_mha_fp8_vs_f16( dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode ): @@ -1841,6 +1841,12 @@ def test_mha_fp8_vs_f16( fp8_dpa=True, fp8_mha=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=True, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( From 7e61ecdd2dd585fbb96476385f04ab660a361980 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 23 Feb 2026 16:26:05 -0800 Subject: [PATCH 29/59] attempt at bwd; force determinism; fix shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 24 +++--- .../common/fused_attn/fused_attn.cpp | 86 +++++++++++++------ .../common/fused_attn/fused_attn_fp8.cu | 42 +++++---- .../common/fused_attn/fused_attn_fp8.h | 2 +- .../include/transformer_engine/fused_attn.h | 10 ++- .../dot_product_attention/backends.py | 8 +- .../dot_product_attention/context_parallel.py | 8 +- .../attention/dot_product_attention/utils.py | 16 +++- .../pytorch/csrc/extensions/attention.cpp | 76 ++++++++++++---- transformer_engine/pytorch/csrc/quantizer.cpp | 3 + 10 files changed, 187 insertions(+), 88 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f0e70280bf..9922d93a77 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2084,7 +2084,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal # config.dropout_p = 0.1 os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + # os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability @@ -2238,16 +2238,18 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if is_training: for i, _ in enumerate(fused_attn_bwd_f16): logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - compare_and_assert( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - True, - ) + print(f"fused_attn_bwd_fp8[{i}].max(): {fused_attn_bwd_fp8[i].max()}, fused_attn_bwd_f16[{i}].max(): {fused_attn_bwd_f16[i].max()}") + print(f"fused_attn_bwd_fp8[{i}].min(): {fused_attn_bwd_fp8[i].min()}, fused_attn_bwd_f16[{i}].min(): {fused_attn_bwd_f16[i].min()}") + # compare_and_assert( + # fused_attn_bwd_fp8[i], + # fused_attn_bwd_f16[i], + # f"fused_attn_bwd_fp8[{i}]", + # f"fused_attn_bwd_f16[{i}]", + # atol, + # rtol, + # rmse_tol, + # True, + # ) os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ed0971627e..0886118451 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -209,49 +209,83 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } // map one NVTE_QKV_Format to another -std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format) { - std::vector dst_shape(src_shape.size()); - size_t b=0, h=0, s=0, d=0, t=0; +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, + size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { + printf("src_format: %d, src_shape: %d, %d, %d, %d, %d\n", src_format, src_shape[0], src_shape[1], src_shape[2], src_shape[3], src_shape[4]); + size_t _b=0, _h=0, _s=0, _d=0, _t=0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: - b = src_shape[0]; - s = src_shape[1]; - h = src_shape[2]; - d = src_shape[3]; + _b = src_shape[0]; + _s = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; break; case NVTE_QKV_Format::NVTE_SBHD: - s = src_shape[0]; - b = src_shape[1]; - h = src_shape[2]; - d = src_shape[3]; + _s = src_shape[0]; + _b = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; break; case NVTE_QKV_Format::NVTE_BHSD: - b = src_shape[0]; - h = src_shape[1]; - s = src_shape[2]; - d = src_shape[3]; + _b = src_shape[0]; + _h = src_shape[1]; + _s = src_shape[2]; + _d = src_shape[3]; break; case NVTE_QKV_Format::NVTE_THD: - t = src_shape[0]; - h = src_shape[1]; - d = src_shape[2]; + _t = src_shape[0]; + _h = src_shape[1]; + _d = src_shape[2]; + break; + default: + NVTE_ERROR("src_format not supported!"); break; } switch (dst_format) { case NVTE_QKV_Format::NVTE_BSHD: - dst_shape = {b, s, h, d}; + dst_shape[0] = _b; + dst_shape[1] = _s; + dst_shape[2] = _h; + dst_shape[3] = _d; break; case NVTE_QKV_Format::NVTE_SBHD: - dst_shape = {s, b, h, d}; + dst_shape[0] = _s; + dst_shape[1] = _b; + dst_shape[2] = _h; + dst_shape[3] = _d; break; case NVTE_QKV_Format::NVTE_BHSD: - dst_shape = {b, h, s, d}; + dst_shape[0] = _b; + dst_shape[1] = _h; + dst_shape[2] = _s; + dst_shape[3] = _d; break; case NVTE_QKV_Format::NVTE_THD: - dst_shape = {t, h, d}; + dst_shape[0] = _t; + dst_shape[1] = _h; + dst_shape[2] = _d; break; + default: + NVTE_ERROR("dst_format not supported!"); + break; + } + printf("dst_format: %d, dst_shape: %d, %d, %d, %d, %d\n", dst_format, dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3], dst_shape[4]); + + if (b != nullptr) { + *b = _b; + } + if (h != nullptr) { + *h = _h; + } + if (s != nullptr) { + *s = _s; + } + if (d != nullptr) { + *d = _d; + } + if (t != nullptr) { + *t = _t; } - return dst_shape; } // select a backend for fused attention @@ -830,7 +864,7 @@ void nvte_fused_attn_bwd_qkvpacked( Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, d, attn_scale, dropout, qkv_layout, qkv_format, qkv_format, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); @@ -1151,7 +1185,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, attn_scale, dropout, - qkv_layout, q_format, q_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, input_O, + qkv_layout, q_format, q_format, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -1393,7 +1427,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f13eef3a66..57b250f6af 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1822,7 +1822,7 @@ void fused_attn_fp8_fwd_impl_v1( auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, false); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) @@ -2038,7 +2038,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -2101,7 +2101,7 @@ void fused_attn_fp8_bwd_impl_v1( window_size_left, window_size_right, bottom_right_diagonal, - false, + deterministic, qkv_tensor_type, o_tensor_type, do_tensor_type, @@ -2265,20 +2265,20 @@ void fused_attn_fp8_bwd_impl_v1( std::vector q_t_stride(4); std::vector k_t_stride(4); std::vector dO_t_stride(4); - generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, true); - generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, true); - generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, true); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], q_t_stride[3]); printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], k_t_stride[3]); printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], dO_t_stride[3]); - printf("qkv_tensor_type: %d\n", qkv_tensor_type); - printf("o_tensor_type: %d\n", o_tensor_type); - printf("do_tensor_type: %d\n", do_tensor_type); - printf("dqkv_tensor_type: %d\n", dqkv_tensor_type); - printf("qkv_layout: %d\n", qkv_layout); - printf("o_format: %d\n", o_format); - printf("d_out_format: %d\n", d_out_format); - printf("dqkv_layout: %d\n", dqkv_layout); + printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); + printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); + printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); + printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); printf("b: %d\n", b); printf("h: %d\n", h); printf("hg: %d\n", hg); @@ -2336,12 +2336,12 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, true); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, true); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, false); generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, dO_scale_strides.data(), d_out_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, true); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); @@ -2431,6 +2431,10 @@ void fused_attn_fp8_bwd_impl_v1( // } // } + if (cudnn_runtime_version >= 92100) { + sdpa_backward_options.set_deterministic_algorithm(deterministic); + } + if (is_padding) { seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("seq_q") @@ -2797,7 +2801,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2866,7 +2870,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 215b5dd92a..9683974a26 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -28,7 +28,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index d866cab702..64eb385584 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -195,13 +195,19 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Convert one NVTE_QKV_Format to another. * - * \param[in] src_shape The source shape. * \param[in] src_format The source format. + * \param[in] src_shape The source shape. * \param[in] dst_format The destination format. + * \param[in,out] dst_shape The destination shape. + * \param[in,out] b The batch size. + * \param[in,out] h The number of heads. + * \param[in,out] s The sequence length. + * \param[in,out] d The head dimension. + * \param[in,out] t The time dimension. * * \return The destination shape. */ - std::vector nvte_convert_qkv_format(std::vector src_shape, NVTE_QKV_Format src_format, NVTE_QKV_Format dst_format); + void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t); /*! \brief Get fused attention backend based on input parameters. * diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d49e7f2365..0ef2dede76 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1301,7 +1301,7 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - + print(f"out_.shape: {out_.shape}, {type(out_)}, qkv_layout: {qkv_layout}, o_format: {o_format}") # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1606,8 +1606,10 @@ def backward(ctx, d_out, *_args): print(f"k_fp8._with_gemm_swizzled_scales: {k_fp8._with_gemm_swizzled_scales}") print(f"v_fp8._with_gemm_swizzled_scales: {v_fp8._with_gemm_swizzled_scales}") print(f"d_out_fp8._with_gemm_swizzled_scales: {d_out_fp8._with_gemm_swizzled_scales}") - print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") - print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") + # print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") + # print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") + print(f"out_.shape: {out_.shape}, d_out_fp8.shape: {d_out_fp8._rowwise_data.shape}, d_out_fp8.columnwise_data.shape: {d_out_fp8._columnwise_data.shape}, d_out.shape: {d_out.shape}") + print(f"out_.stride: {out_.stride()}, d_out_fp8.rowwise_data.stride: {d_out_fp8._rowwise_data.stride()}, d_out_fp8.columnwise_data.stride: {d_out_fp8._columnwise_data.stride()}, d_out.stride: {d_out.stride()}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 34af861604..22d8378598 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3403,10 +3403,10 @@ def forward( fused_attn_backend = FusedAttnBackend["FP8"] if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + elif not isinstance(QKV_quantizer, MXFP8Quantizer): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index e8a4170cf3..16410e8e00 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1069,10 +1069,10 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_fused_attention = False fused_attention_backend = None - if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons with FP8") - use_fused_attention = False - fused_attention_backend = None + # if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: + # logger.debug("Disabling FusedAttention for determinism reasons with FP8") + # use_fused_attention = False + # fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training @@ -2241,6 +2241,14 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] + print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") + print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") + print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") + print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") + print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") + print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") + print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") + print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") return q_fp8, k_fp8, v_fp8, qkv_layout diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index fc870c4591..d1b1baed30 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -92,6 +92,13 @@ std::pair quantizer_helper(py::handle quantizer, "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + printf("in quantizer_helper\n"); + printf("create_hp_tensor_for_cs: %d\n", create_hp_tensor_for_cs); + printf("data.has_value(): %d\n", data.has_value()); + printf("shape: %d, %d, %d, %d, %d\n", shape[0], shape[1], shape[2], shape[3], shape[4]); + printf("dtype: %d\n", dtype); + printf("quantizer: %p\n", quantizer.ptr()); + // MXFP8 auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { @@ -99,6 +106,7 @@ std::pair quantizer_helper(py::handle quantizer, std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); } else { + printf("in quantizer_helper, creating unquantized tensor with amax\n"); std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); } } else { @@ -152,7 +160,11 @@ std::vector fused_attn_fwd( std::vector v_shape = convertShape(te_V.shape()); auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; - auto o_shape = nvte_convert_qkv_format(o_shape_tmp, nvte_get_q_format(qkv_layout), o_format); + auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; + size_t b=0, h=0, s=0, d=0, t=0; + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); + printf("b: %d, h: %d, s: %d, d: %d, t: %d\n", b, h, s, d, t); + printf("o_shape: %d, %d, %d, %d, %d\n", o_shape[0], o_shape[1], o_shape[2], o_shape[3]); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -163,8 +175,8 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; + // auto h = q_shape[q_shape.size() - 2]; + // auto d = q_shape[q_shape.size() - 1]; if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -316,7 +328,14 @@ std::vector fused_attn_fwd( softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); - + printf("after nvte_fused_attn_fwd\n"); + float *amax_cpu; + amax_cpu = (float *)malloc(sizeof(float)); + *amax_cpu=0.0; + cudaMemcpy(amax_cpu, te_O.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("O amax_cpu: %f\n", *amax_cpu); + // printf("py_O.amax(): %f\n", py_O.attr("amax").cast().cpu().item()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -360,14 +379,17 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; + // auto h_q = q_shape[q_shape.size() - 2]; + // auto h_kv = k_shape[k_shape.size() - 2]; + // auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + size_t b=0, h_q=0, h_kv=0, s_q=0, s_kv=0, d_qk=0, d_v=0, t_q=0, t_kv=0; + std::vector dQ_shape(4), dK_shape(4), dV_shape(4); + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); - std::vector tmp_shape; DType dqkv_type = fake_dtype_te; if (!dqkv_quantizer.is_none()) { dqkv_type = dqkv_quantizer.attr("dtype").cast(); @@ -380,10 +402,12 @@ std::vector fused_attn_bwd( if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); + std::vector tmp_shape; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -400,7 +424,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -414,9 +438,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -429,9 +453,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -442,11 +466,12 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); dK = torch::empty(tmp_shape, options); - tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + tmp_shape = std::vector(dV_shape.begin(), dV_shape.end()); dV = torch::empty(tmp_shape, options); break; default: @@ -582,6 +607,21 @@ std::vector fused_attn_bwd( at::cuda::getCurrentCUDAStream()); }); + float *amax_cpu; + amax_cpu = (float *)malloc(sizeof(float)); + *amax_cpu=0.0; + cudaMemcpy(amax_cpu, te_dQ.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("dQ amax_cpu: %f\n", *amax_cpu); + cudaMemcpy(amax_cpu, te_dK.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("dK amax_cpu: %f\n", *amax_cpu); + cudaMemcpy(amax_cpu, te_dV.amax(), sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("dV amax_cpu: %f\n", *amax_cpu); + // printf("py_dQ.amax(): %f\n", py_dQ.attr("amax").cast().cpu().item()); + // printf("py_dK.amax(): %f\n", py_dK.attr("amax").cast().cpu().item()); + // printf("py_dV.amax(): %f\n", py_dV.attr("amax").cast().cpu().item()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 20820143b0..d5c3ea00b4 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -949,6 +949,9 @@ std::pair MXFP8Quantizer::create_unquantized_tensor_w TensorWrapper out_cpp = std::move(out.first); py::object out_py = std::move(out.second); out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + printf("after MXFP8Quantizer::create_unquantized_tensor_with_amax\n"); + printf("amax_ptr: %p\n", amax_tensor.data_ptr()); + printf("out_cpp.amax(): %f\n", amax_tensor.cpu().item()); return {std::move(out_cpp), std::move(out_py)}; } From 6d468da04cbc32aa06b6b22a7efc180a5f9159c4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:21:36 -0800 Subject: [PATCH 30/59] remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 4 --- .../dot_product_attention/backends.py | 33 ------------------ .../attention/dot_product_attention/utils.py | 8 ----- .../pytorch/csrc/extensions/attention.cpp | 34 ------------------- transformer_engine/pytorch/csrc/quantizer.cpp | 3 -- 5 files changed, 82 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0886118451..557cb3eea1 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -211,7 +211,6 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // map one NVTE_QKV_Format to another void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { - printf("src_format: %d, src_shape: %d, %d, %d, %d, %d\n", src_format, src_shape[0], src_shape[1], src_shape[2], src_shape[3], src_shape[4]); size_t _b=0, _h=0, _s=0, _d=0, _t=0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: @@ -269,7 +268,6 @@ void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src NVTE_ERROR("dst_format not supported!"); break; } - printf("dst_format: %d, dst_shape: %d, %d, %d, %d, %d\n", dst_format, dst_shape[0], dst_shape[1], dst_shape[2], dst_shape[3], dst_shape[4]); if (b != nullptr) { *b = _b; @@ -1380,8 +1378,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, cuda_graph, deterministic); - printf("Q_type: %d, KV_type: %d, qkv_layout: %d, bias_type: %d, attn_mask_type: %d, softmax_type: %d, dropout: %f, h_q: %d, h_kv: %d, max_seqlen_q: %d, max_seqlen_kv: %d, d_qk: %d, d_v: %d, window_size_left: %d, window_size_right: %d, deterministic: %d\n", Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, deterministic); - printf("fused_attention_backend: %d\n", fused_attention_backend); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0ef2dede76..08b9dca6d7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1301,7 +1301,6 @@ def forward( softmax_offset, cuda_graph=is_graph_capturing(), ) - print(f"out_.shape: {out_.shape}, {type(out_)}, qkv_layout: {qkv_layout}, o_format: {o_format}") # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1481,33 +1480,18 @@ def forward( def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - print(f"ctx.fp8: {ctx.fp8}, ctx.is_output_fp8: {ctx.is_output_fp8}, isinstance(d_out, QuantizedTensorStorage): {isinstance(d_out, QuantizedTensorStorage)}") - # # reshape d_out to ctx.qkv_layout; only happens with MXFP8BlockScaling - # if ctx.original_qkv_layout != ctx.qkv_layout: - # print(f"ctx.original_qkv_layout: {ctx.original_qkv_layout}, ctx.qkv_layout: {ctx.qkv_layout}") - # print(f"d_out before reshape: {d_out.shape}, {type(d_out)}") - # original_qkv_format = ctx.original_qkv_layout.split("_")[0] - # new_qkv_format = ctx.qkv_layout.split("_")[0] - # perm = [] - # for i in original_qkv_format: - # perm.append(new_qkv_format.find(i)) - # d_out = d_out.permute(*perm).contiguous() - # print(f"d_out after reshape: {d_out.shape}, {type(d_out)}") - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 d_out_fp8 = None d_out_format = ctx.o_format if ctx.fp8: - print(f"d_out before quantizer: {d_out.shape}, {type(d_out)}") if ctx.fp8_recipe.mxfp8(): d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) if isinstance(d_out, QuantizedTensorStorage): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - print(f"d_out after quantizer: {d_out.shape}, {d_out_fp8._rowwise_data.shape}, {type(d_out)}, {type(d_out_fp8)}") if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( @@ -1602,14 +1586,6 @@ def backward(ctx, d_out, *_args): if ctx.fp8_recipe.mxfp8(): out_ = out aux_ctx_tensors.append(d_out) - print(f"q_fp8._with_gemm_swizzled_scales: {q_fp8._with_gemm_swizzled_scales}") - print(f"k_fp8._with_gemm_swizzled_scales: {k_fp8._with_gemm_swizzled_scales}") - print(f"v_fp8._with_gemm_swizzled_scales: {v_fp8._with_gemm_swizzled_scales}") - print(f"d_out_fp8._with_gemm_swizzled_scales: {d_out_fp8._with_gemm_swizzled_scales}") - # print(f"types: {type(q_fp8)}, {type(k_fp8)}, {type(v_fp8)}, {type(out_)}, {type(d_out_fp8)}, {[type(x) for x in aux_ctx_tensors]}") - # print(f"shapes: {q_fp8._rowwise_data.shape}, {k_fp8._rowwise_data.shape}, {v_fp8._rowwise_data.shape}, {out_.shape}, {d_out_fp8._rowwise_data.shape}, {[x.shape for x in aux_ctx_tensors]}") - print(f"out_.shape: {out_.shape}, d_out_fp8.shape: {d_out_fp8._rowwise_data.shape}, d_out_fp8.columnwise_data.shape: {d_out_fp8._columnwise_data.shape}, d_out.shape: {d_out.shape}") - print(f"out_.stride: {out_.stride()}, d_out_fp8.rowwise_data.stride: {d_out_fp8._rowwise_data.stride()}, d_out_fp8.columnwise_data.stride: {d_out_fp8._columnwise_data.stride()}, d_out.stride: {d_out.stride()}") dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1644,15 +1620,6 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) - print(f"dq_.shape: {dq_.shape}, dk_.shape: {dk_.shape}, dv_.shape: {dv_.shape}") - print(f"types: {type(dq_)}, {type(dk_)}, {type(dv_)}") - # if ctx.original_qkv_layout != ctx.qkv_layout: - # original_qkv_format = ctx.original_qkv_layout.split("_")[0] - # new_qkv_format = ctx.qkv_layout.split("_")[0] - # perm = [] - # for i in new_qkv_format: - # perm.append(original_qkv_format.find(i)) - # dq_, dk_, dv_ = [x.permute(*perm).contiguous() for x in (dq_, dk_, dv_)] # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 16410e8e00..03a52ab870 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2241,14 +2241,6 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] - print(f"q_fp8._rowwise_data.shape: {q_fp8._rowwise_data.shape}, k_fp8._rowwise_data.shape: {k_fp8._rowwise_data.shape}, v_fp8._rowwise_data.shape: {v_fp8._rowwise_data.shape}") - print(f"q_fp8._rowwise_scale_inv.shape: {q_fp8._rowwise_scale_inv.shape}, k_fp8._rowwise_scale_inv.shape: {k_fp8._rowwise_scale_inv.shape}, v_fp8._rowwise_scale_inv.shape: {v_fp8._rowwise_scale_inv.shape}") - print(f"q_fp8._columnwise_data.shape: {q_fp8._columnwise_data.shape}, k_fp8._columnwise_data.shape: {k_fp8._columnwise_data.shape}, v_fp8._columnwise_data.shape: {v_fp8._columnwise_data.shape}") - print(f"q_fp8._columnwise_scale_inv.shape: {q_fp8._columnwise_scale_inv.shape}, k_fp8._columnwise_scale_inv.shape: {k_fp8._columnwise_scale_inv.shape}, v_fp8._columnwise_scale_inv.shape: {v_fp8._columnwise_scale_inv.shape}") - print(f"q_fp8._rowwise_data.stride: {q_fp8._rowwise_data.stride()}, k_fp8._rowwise_data.stride: {k_fp8._rowwise_data.stride()}, v_fp8._rowwise_data.stride: {v_fp8._rowwise_data.stride()}") - print(f"q_fp8._rowwise_scale_inv.stride: {q_fp8._rowwise_scale_inv.stride()}, k_fp8._rowwise_scale_inv.stride: {k_fp8._rowwise_scale_inv.stride()}, v_fp8._rowwise_scale_inv.stride: {v_fp8._rowwise_scale_inv.stride()}") - print(f"q_fp8._columnwise_data.stride: {q_fp8._columnwise_data.stride()}, k_fp8._columnwise_data.stride: {k_fp8._columnwise_data.stride()}, v_fp8._columnwise_data.stride: {v_fp8._columnwise_data.stride()}") - print(f"q_fp8._columnwise_scale_inv.stride: {q_fp8._columnwise_scale_inv.stride()}, k_fp8._columnwise_scale_inv.stride: {k_fp8._columnwise_scale_inv.stride()}, v_fp8._columnwise_scale_inv.stride: {v_fp8._columnwise_scale_inv.stride()}") return q_fp8, k_fp8, v_fp8, qkv_layout diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index d1b1baed30..fd193b0258 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -92,13 +92,6 @@ std::pair quantizer_helper(py::handle quantizer, "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { - printf("in quantizer_helper\n"); - printf("create_hp_tensor_for_cs: %d\n", create_hp_tensor_for_cs); - printf("data.has_value(): %d\n", data.has_value()); - printf("shape: %d, %d, %d, %d, %d\n", shape[0], shape[1], shape[2], shape[3], shape[4]); - printf("dtype: %d\n", dtype); - printf("quantizer: %p\n", quantizer.ptr()); - // MXFP8 auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { @@ -106,7 +99,6 @@ std::pair quantizer_helper(py::handle quantizer, std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); } else { - printf("in quantizer_helper, creating unquantized tensor with amax\n"); std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); } } else { @@ -163,8 +155,6 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; size_t b=0, h=0, s=0, d=0, t=0; nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); - printf("b: %d, h: %d, s: %d, d: %d, t: %d\n", b, h, s, d, t); - printf("o_shape: %d, %d, %d, %d, %d\n", o_shape[0], o_shape[1], o_shape[2], o_shape[3]); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -328,14 +318,6 @@ std::vector fused_attn_fwd( softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); - printf("after nvte_fused_attn_fwd\n"); - float *amax_cpu; - amax_cpu = (float *)malloc(sizeof(float)); - *amax_cpu=0.0; - cudaMemcpy(amax_cpu, te_O.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("O amax_cpu: %f\n", *amax_cpu); - // printf("py_O.amax(): %f\n", py_O.attr("amax").cast().cpu().item()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -394,7 +376,6 @@ std::vector fused_attn_bwd( if (!dqkv_quantizer.is_none()) { dqkv_type = dqkv_quantizer.attr("dtype").cast(); } - printf(">>>>>> dQKV_type: %d\n", dqkv_type); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); @@ -607,21 +588,6 @@ std::vector fused_attn_bwd( at::cuda::getCurrentCUDAStream()); }); - float *amax_cpu; - amax_cpu = (float *)malloc(sizeof(float)); - *amax_cpu=0.0; - cudaMemcpy(amax_cpu, te_dQ.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("dQ amax_cpu: %f\n", *amax_cpu); - cudaMemcpy(amax_cpu, te_dK.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("dK amax_cpu: %f\n", *amax_cpu); - cudaMemcpy(amax_cpu, te_dV.amax(), sizeof(float), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - printf("dV amax_cpu: %f\n", *amax_cpu); - // printf("py_dQ.amax(): %f\n", py_dQ.attr("amax").cast().cpu().item()); - // printf("py_dK.amax(): %f\n", py_dK.attr("amax").cast().cpu().item()); - // printf("py_dV.amax(): %f\n", py_dV.attr("amax").cast().cpu().item()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d5c3ea00b4..20820143b0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -949,9 +949,6 @@ std::pair MXFP8Quantizer::create_unquantized_tensor_w TensorWrapper out_cpp = std::move(out.first); py::object out_py = std::move(out.second); out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); - printf("after MXFP8Quantizer::create_unquantized_tensor_with_amax\n"); - printf("amax_ptr: %p\n", amax_tensor.data_ptr()); - printf("out_cpp.amax(): %f\n", amax_tensor.cpu().item()); return {std::move(out_cpp), std::move(out_py)}; } From 9f8e856a3db99b3bbb3898f361345c862e3a1bf9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:22:17 -0800 Subject: [PATCH 31/59] update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 4b4df2edcf..ae385ad82e 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 4b4df2edcf80b7cc7a659da5734454577154aa6d +Subproject commit ae385ad82e476bb75910d1ce92c6e25fdae42f40 From facef79b9dfc18fb04c12dcca63782ac50ecf222 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:42:42 -0800 Subject: [PATCH 32/59] update FE from pre-merge branch to post-merge develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index ae385ad82e..b4370f5198 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit ae385ad82e476bb75910d1ce92c6e25fdae42f40 +Subproject commit b4370f5198bd95ee758ebc2c6b76b887914b702d From fd33cca2dbe607ac2bed257d00eb65e90a30b896 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 13:07:36 -0800 Subject: [PATCH 33/59] allow MXFP8 linear + f16 attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index f7699340e6..60b9812eb8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -98,19 +98,19 @@ +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | -| | | export NVTE_DPA_FP8_RECIPE="F16" | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS, NVFP4 or MXFP8 to autocast(); | +| /MXFP8 | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass FP8DS to autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +118,19 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| | | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | From 5079d5588be0016f2e2244e7fd459185340d5f27 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 17:58:21 -0800 Subject: [PATCH 34/59] test cp a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 27 +- tests/pytorch/attention/test_attention.py | 4 +- .../attention/test_attention_with_cp.py | 21 +- .../common/fused_attn/fused_attn_fp8.cu | 46 +++- transformer_engine/common/fused_attn/utils.h | 7 + .../dot_product_attention/context_parallel.py | 230 +++++++++++++----- .../attention/dot_product_attention/utils.py | 15 +- 7 files changed, 261 insertions(+), 89 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 3efb516b57..b019289846 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -19,8 +19,9 @@ DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Format from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -189,7 +190,7 @@ def run_dpa_with_cp( os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" fp8_mha = fp8_mha == "True" and dtype == "fp8" - f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -219,6 +220,7 @@ def run_dpa_with_cp( device_count = torch.cuda.device_count() device = rank % device_count torch.cuda.set_device(device) + print(f"rank: {rank}, world_size: {world_size}") logging.info(f"[Rank {rank}] Setup: world_size {world_size}") dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) @@ -244,6 +246,8 @@ def run_dpa_with_cp( fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -297,10 +301,25 @@ def run_dpa_with_cp( fp8_dtype=tex.DType.kFloat8E5M2, device="cuda", ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] if fp8_mha: - q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -386,7 +405,7 @@ def run_dpa_with_cp( dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) if fp8_mha: - q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 9922d93a77..dc0d37f555 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1788,8 +1788,8 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128),#, attn_mask_type="causal"), - "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128), #, num_gqa_groups=12, window_size=(512, 512)), + "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),#, attn_mask_type="causal"), + "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 836598087b..668c2745c7 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -17,6 +17,8 @@ from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils @@ -149,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 128),#, num_gqa_groups=12), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -166,7 +168,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -192,14 +194,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_1", "cp_1_4", "cp_2_0", + "cp_2_1", "cp_2_2", "cp_2_4", + "cp_3_1", "cp_3_2", "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["sbhd", "thd"] + qkv_formats = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -211,7 +215,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O @@ -280,7 +284,7 @@ def test_cp_with_fused_attention( and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] ): pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode != "current"): + if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") @@ -301,6 +305,8 @@ def test_cp_with_fused_attention( "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" " non-vanilla softmax types!" ) + if scaling_mode == "mxfp8" and not f16_O: + pytest.skip("MXFP8 only works with f16_O=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -324,6 +330,11 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + if fp8 and scaling_mode == "mxfp8": + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True) + fp8_meta["local_recipes"] = [ + MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True), + ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 57b250f6af..da826688be 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2399,17 +2399,17 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - // fe::DiagonalAlignment_t const &diagonal_alignment = - // bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT - // : fe::DiagonalAlignment_t::TOP_LEFT; - // sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); - // if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - // sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); - // } - // if (cudnn_runtime_version >= 90600 && window_size_right != -1) { - // sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); - // } + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } // sdpa_backward_options.set_alibi_mask(is_alibi); @@ -2632,9 +2632,33 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dO_t] = devPtrdO_t; variant_pack[descale_q_t] = devPtrDescaleQ_t; variant_pack[descale_k_t] = devPtrDescaleK_t; - variant_pack[descale_dO] = devPtrDescaledO; + // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } + int64_t modulo = 16; + printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); + printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); + printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); + printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); + printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); + printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); + printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, is_aligned_modulo(devPtrDescaleQ, modulo)); + printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, is_aligned_modulo(devPtrDescaleK, modulo)); + printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, is_aligned_modulo(devPtrDescaleV, modulo)); + printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, is_aligned_modulo(devPtrDescaledO, modulo)); + printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, is_aligned_modulo(devPtrDescaledO_t, modulo)); + printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); + printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); + printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); + printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, is_aligned_modulo(devPtrAmaxdQ, modulo)); + printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, is_aligned_modulo(devPtrAmaxdK, modulo)); + printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, is_aligned_modulo(devPtrAmaxdV, modulo)); + printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); + printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); + printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, is_aligned_modulo(devPtrdO_f16, modulo)); + printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); + printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, is_aligned_modulo(devPtrDescaleQ_t, modulo)); + printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, is_aligned_modulo(devPtrDescaleK_t, modulo)); /* if (is_bias) { variant_pack[bias] = devPtrBias; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f0b947c379..43d460bfd1 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -49,6 +49,13 @@ struct MXFP8PaddedSizes { int64_t d_v_scale_padded; }; +inline bool is_aligned_modulo(void* ptr, int64_t modulo) { + // Cast the pointer to a large enough integer type (uintptr_t) + uintptr_t address = reinterpret_cast(ptr); + // Check if the address is perfectly divisible by 16 + return (address % modulo) == 0; +} + // Pad s and d for MXFP8 layout inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { constexpr int64_t block_size = 32; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 22d8378598..b701c37fe6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -23,6 +23,8 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.common.recipe import MXFP8BlockScaling, Format from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.constants import ( @@ -58,6 +60,16 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def get_bsh_dims(tensor_format): + """Get batch dimension and sequence dimension from tensor format""" + if tensor_format in ["bshd", "sbhd", "bhsd"]: + batch_dim = tensor_format.index("b") + seq_dim = tensor_format.index("s") + head_dim = tensor_format.index("h") + else: # tensor_format == "thd" + batch_dim = seq_dim = tensor_format.index("t") + head_dim = tensor_format.index("h") + return batch_dim, seq_dim, head_dim def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm @@ -419,6 +431,7 @@ def flash_attn_a2a_communicate( ), "cu_seqlens_padded is required for THD format!" a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + batch_dim, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -430,13 +443,14 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # reorder the sequence chunks x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # or [b, np//cp, cp*2, s//2, hn] -> [b, np//cp, cp*s, hn] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) @@ -452,12 +466,14 @@ def flash_attn_a2a_communicate( x = a2a_inputs[i] # [b, s, np, hn] -> [b, s, cp, np//cp, hn] # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # or [b, np, s, hn] -> [b, cp, np//cp, s, hn] # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + x = x.view(*x.shape[:head_dim], cp_size, x.shape[head_dim] // cp_size, *x.shape[head_dim + 1:]) # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # or [b, cp, np//cp, s, hn] -> [cp, b, np//cp, s, hn] # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() + a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -467,9 +483,10 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # or [b, np//cp, cp*s, hn] -> [b, np//cp, cp*2, s//2, hn] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -486,10 +503,12 @@ def flash_attn_a2a_communicate( x = a2a_outputs[i - 2] # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + x = x.movedim(0, head_dim+1).movedim(0, seq_dim+1).contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] # or [t, cp, np//cp, hn] -> [t, np, hn] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3367,21 +3386,19 @@ def forward( ), "The number of attention heads needs to be divisible by CP size!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - if qkv_format in ["bshd", "sbhd"]: - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - batch_dim = seq_dim = qkv_format.index("t") + original_qkv_layout = qkv_layout + o_format = qkv_format + batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -3392,6 +3409,9 @@ def forward( fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None + if torch.cuda.current_device() == 0: + print(f"is_input_fp8: {is_input_fp8}, is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}") + print(f"fp8: {fp8}, fp8_recipe: {fp8_recipe}") QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) @@ -3403,10 +3423,14 @@ def forward( fused_attn_backend = FusedAttnBackend["FP8"] if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - elif not isinstance(QKV_quantizer, MXFP8Quantizer): + elif not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + # else: + # q, k, v = [q_fp8, k_fp8, v_fp8] + # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) + # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer @@ -3417,11 +3441,15 @@ def forward( fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if torch.cuda.current_device() == 0: + print(f"before flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") + print(f"qkv_format: {qkv_format}, o_format: {o_format}") + print(f"batch_dim_qkv: {batch_dim_qkv}, seq_dim_qkv: {seq_dim_qkv}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, - seq_dim, + seq_dim_qkv, cp_size, cp_group, cp_stream, @@ -3429,6 +3457,8 @@ def forward( qkv_format=qkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True @@ -3436,15 +3466,20 @@ def forward( out_fp8 = None out_f16 = None - batch_size = q.shape[batch_dim] + batch_size = q.shape[batch_dim_qkv] q_part, k_part, v_part = q, k, v out_part = None if use_fused_attention: - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_part, k_part, v_part = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] + if fp8 and fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + if torch.cuda.current_device() == 0: + print(f"before fused_attn_fwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}") out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3459,6 +3494,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3471,7 +3507,9 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if isinstance(out_, Float8Tensor): + if torch.cuda.current_device() == 0: + print(f"after fused_attn_fwd: out_: {out_.shape} {type(out_)}") + if isinstance(out_, QuantizedTensorStorage): out_fp8 = out_ out_ = out_._data if is_bwd_fp8 and not ( @@ -3487,6 +3525,7 @@ def forward( fp8 and is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() ): out_part = O_quantizer(out_) else: @@ -3516,33 +3555,39 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ + if torch.cuda.current_device() == 0: + print(f"before flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, - seq_dim, + seq_dim_o, cp_size, cp_group, cp_stream, before_attn=False, - qkv_format=qkv_format, + qkv_format=o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( *max_logit, 0, cp_size, cp_group, cp_stream, False ) if use_fused_attention: - if qkv_format == "bshd": + if o_format == "bshd": # [b*s, h, d] -> [b, s, h, d] out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out_ = out_.view(-1, batch_size, *out_.shape[-2:]) + if torch.cuda.current_device() == 0: + print(f"after view: out_: {out_.shape} {type(out_)}") if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_f16 = out_ if is_output_fp8: out_fp8 = O_quantizer(out_) @@ -3556,19 +3601,28 @@ def forward( out_ret = out_fp8 if is_output_fp8 else out_f16 ctx.fp8 = fp8 and is_bwd_fp8 + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) if ctx.fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): fp8_tensors = (q_part, k_part, v_part, None) f16_tensors = (None, None, None, out_part) else: fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8: + elif fp8 and not fp8_recipe.mxfp8(): q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) f16_tensors = (q_part, k_part, v_part, out_part) + elif fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout else: f16_tensors = (q_part, k_part, v_part, out_part) + if torch.cuda.current_device() == 0: + print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]}, type of fp8_tensors: {[type(x) for x in fp8_tensors]}") + print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]}, type of f16_tensors: {[type(x) for x in f16_tensors]}") tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3590,7 +3644,7 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format + # ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -3612,11 +3666,13 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -3644,27 +3700,28 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_format = ctx.qkv_format - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + # qkv_format = ctx.qkv_format + # qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + # qkv_layout = ctx.qkv_layout causal = "causal" in ctx.attn_mask_type + dqkv_format, _, _ = dpa_utils.get_qkv_format(ctx.dqkv_layout) - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - seq_dim = qkv_format.index("t") + batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(dqkv_format) + _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype - dqkv_te_dtype = None + # dqkv_te_dtype = None fused_attn_backend = None dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage): + if not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dout = ctx.dO_quantizer(dout) dout_fp8 = dout - dqkv_te_dtype = dout._fp8_dtype - dout = dout._data + if not ctx.fp8_recipe.mxfp8(): + # dqkv_te_dtype = dout._fp8_dtype + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -3677,29 +3734,32 @@ def backward(ctx, dout, *_args): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - dqkv_te_dtype = TE_DType[dout.dtype] + # dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: - if qkv_format in ["bshd", "sbhd"]: + if ctx.o_format in ["bshd", "sbhd"]: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) else: dout = dout.view(*ctx.out_shape) + if torch.cuda.current_device() == 0: + print(f"before flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, chunk_ids_for_a2a, - seq_dim, + seq_dim_do, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=True, - qkv_format=qkv_format, + qkv_format=ctx.o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3714,7 +3774,7 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if qkv_format == "thd": + if ctx.o_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3740,13 +3800,31 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["softcap"] = 0.0 dq_fp8, dk_fp8, dv_fp8 = None, None, None + d_out_format = ctx.o_format if ctx.use_fused_attention: q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8(): out_part = out - dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + if not ctx.fp8_recipe.mxfp8(): + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + else: + dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + dout_part = ctx.dO_quantizer(dout) + print(f"dout.ptr: {hex(dout.data_ptr())}, {hex(dout_part._rowwise_data.data_ptr())}, {hex(dout_part._columnwise_data.data_ptr())}, {hex(dout_part._rowwise_scale_inv.data_ptr())}, {hex(dout_part._columnwise_scale_inv.data_ptr())}") + aux_ctx_tensors.append(dout) + if torch.cuda.current_device() == 0: + print(f"before fused_attn_bwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}, out_part: {out_part.shape} {type(out_part)}, dout_part: {dout_part.shape} {type(dout_part)}") + print(f"type of aux_ctx_tensors: {[type(x) for x in aux_ctx_tensors]} {[x.shape if x is not None else None for x in aux_ctx_tensors]}") + print(f"fused_attn_backend: {fused_attn_backend}") + # print(f"cu_seqlens_q: {cu_seqlens_q.shape} {type(cu_seqlens_q)}, cu_seqlens_kv: {cu_seqlens_kv.shape} {type(cu_seqlens_kv)}") + # print(f"cu_seqlens_q_padded: {cu_seqlens_q_padded.shape} {type(cu_seqlens_q_padded)}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded.shape} {type(cu_seqlens_kv_padded)}") + # print(f"ctx.softmax_scale: {ctx.softmax_scale}, ctx.dropout_p: {ctx.dropout_p}, ctx.window_size: {ctx.window_size}, ctx.deterministic: {ctx.deterministic}") + print(f"ctx.qkv_layout: {ctx.qkv_layout}, ctx.o_format: {ctx.o_format}, ctx.dqkv_layout: {ctx.dqkv_layout}") + # print(f"ctx.attn_mask_type: {ctx.attn_mask_type}, ctx.attn_bias_type: {ctx.attn_bias_type}") + print(f"is contiguous: {q_part.is_contiguous()}, {k_part.is_contiguous()}, {v_part.is_contiguous()}, {out_part.is_contiguous()}, {dout_part.is_contiguous()}") + print(fp8_meta_kwargs) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3758,14 +3836,17 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - dqkv_te_dtype, + # dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + d_out_format=d_out_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, @@ -3774,7 +3855,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, Float8Tensor): + if isinstance(dq, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: @@ -3783,7 +3864,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - qkv_format, + ctx.o_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3806,22 +3887,27 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) + if torch.cuda.current_device() == 0: + print(f"after flash_attn_bwd: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") + print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, - seq_dim, + seq_dim_dqkv, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=qkv_format, + qkv_format=dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - - if qkv_format == "bshd": + if torch.cuda.current_device() == 0: + print(f"after flash_attn_a2a_communicate: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") + print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") + if dqkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif qkv_format == "sbhd": + elif dqkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] d_bias = None @@ -3836,8 +3922,8 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: - if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()) and ctx.is_input_fp8: + dq, dk, dv = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) @@ -3845,13 +3931,15 @@ def backward(ctx, dout, *_args): ] if not ctx.is_input_fp8: dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.dqkv_layout, dq, dk, dv, src_nominal_dtype=bwd_nominal_dtype, ) - + if torch.cuda.current_device() == 0: + print(f"after combine_and_dequantize: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") + print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -3982,7 +4070,19 @@ def attn_forward_func_with_cp( in Megatron-LM. """ - + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_comm_type=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {qkv_format=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {deterministic=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {use_fused_attention=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_meta=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_group=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_global_ranks=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_stream=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {quantizers=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {pad_between_seqs=}") + print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_output=}") + # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {layer_number=}") if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 03a52ab870..eb2a7f8e94 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -515,6 +515,17 @@ def get_attention_backend( " with cuDNN < 9.18.0" ) use_fused_attention = False + if use_fused_attention and fp8_recipe.mxfp8(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") + use_fused_attention = False + else: + if cudnn_version < (9, 21, 0): + logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") + use_fused_attention = False + elif qkv_format == "thd": + logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -2096,7 +2107,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) S_quantizer = quantizers["scaling_fwd"][META_S] @@ -2108,7 +2119,7 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): O_quantizer.set_usage(rowwise=True, columnwise=False) dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.internal = True + dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer = quantizers["scaling_bwd"][META_DP] From 06b7d491c6a819b6977bf6a7721351ffcdfaeb31 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 18:04:34 -0800 Subject: [PATCH 35/59] remove prints temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index b701c37fe6..78c19826c8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3409,9 +3409,6 @@ def forward( fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None - if torch.cuda.current_device() == 0: - print(f"is_input_fp8: {is_input_fp8}, is_output_fp8: {is_output_fp8}, is_bwd_fp8: {is_bwd_fp8}") - print(f"fp8: {fp8}, fp8_recipe: {fp8_recipe}") QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) @@ -3441,10 +3438,6 @@ def forward( fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if torch.cuda.current_device() == 0: - print(f"before flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") - print(f"qkv_format: {qkv_format}, o_format: {o_format}") - print(f"batch_dim_qkv: {batch_dim_qkv}, seq_dim_qkv: {seq_dim_qkv}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], @@ -3457,8 +3450,6 @@ def forward( qkv_format=qkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: q: {q.shape} {type(q)}, k: {k.shape} {type(k)}, v: {v.shape} {type(v)}") if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True @@ -3478,8 +3469,6 @@ def forward( if fp8 and fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] - if torch.cuda.current_device() == 0: - print(f"before fused_attn_fwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}") out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3507,8 +3496,6 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if torch.cuda.current_device() == 0: - print(f"after fused_attn_fwd: out_: {out_.shape} {type(out_)}") if isinstance(out_, QuantizedTensorStorage): out_fp8 = out_ out_ = out_._data @@ -3555,8 +3542,6 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ - if torch.cuda.current_device() == 0: - print(f"before flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, @@ -3569,8 +3554,6 @@ def forward( qkv_format=o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: out_: {out_.shape} {type(out_)}") if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( *max_logit, 0, cp_size, cp_group, cp_stream, False @@ -3583,8 +3566,6 @@ def forward( elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if torch.cuda.current_device() == 0: - print(f"after view: out_: {out_.shape} {type(out_)}") if fp8 and use_fused_attention: if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): @@ -3620,9 +3601,6 @@ def forward( ctx.qkv_layout = original_qkv_layout else: f16_tensors = (q_part, k_part, v_part, out_part) - if torch.cuda.current_device() == 0: - print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]}, type of fp8_tensors: {[type(x) for x in fp8_tensors]}") - print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]}, type of f16_tensors: {[type(x) for x in f16_tensors]}") tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3744,8 +3722,6 @@ def backward(ctx, dout, *_args): else: dout = dout.view(*ctx.out_shape) - if torch.cuda.current_device() == 0: - print(f"before flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, @@ -3758,8 +3734,6 @@ def backward(ctx, dout, *_args): qkv_format=ctx.o_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: dout: {dout.shape} {type(dout)}") flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3812,19 +3786,7 @@ def backward(ctx, dout, *_args): else: dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) dout_part = ctx.dO_quantizer(dout) - print(f"dout.ptr: {hex(dout.data_ptr())}, {hex(dout_part._rowwise_data.data_ptr())}, {hex(dout_part._columnwise_data.data_ptr())}, {hex(dout_part._rowwise_scale_inv.data_ptr())}, {hex(dout_part._columnwise_scale_inv.data_ptr())}") aux_ctx_tensors.append(dout) - if torch.cuda.current_device() == 0: - print(f"before fused_attn_bwd: q_part: {q_part.shape} {type(q_part)}, k_part: {k_part.shape} {type(k_part)}, v_part: {v_part.shape} {type(v_part)}, out_part: {out_part.shape} {type(out_part)}, dout_part: {dout_part.shape} {type(dout_part)}") - print(f"type of aux_ctx_tensors: {[type(x) for x in aux_ctx_tensors]} {[x.shape if x is not None else None for x in aux_ctx_tensors]}") - print(f"fused_attn_backend: {fused_attn_backend}") - # print(f"cu_seqlens_q: {cu_seqlens_q.shape} {type(cu_seqlens_q)}, cu_seqlens_kv: {cu_seqlens_kv.shape} {type(cu_seqlens_kv)}") - # print(f"cu_seqlens_q_padded: {cu_seqlens_q_padded.shape} {type(cu_seqlens_q_padded)}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded.shape} {type(cu_seqlens_kv_padded)}") - # print(f"ctx.softmax_scale: {ctx.softmax_scale}, ctx.dropout_p: {ctx.dropout_p}, ctx.window_size: {ctx.window_size}, ctx.deterministic: {ctx.deterministic}") - print(f"ctx.qkv_layout: {ctx.qkv_layout}, ctx.o_format: {ctx.o_format}, ctx.dqkv_layout: {ctx.dqkv_layout}") - # print(f"ctx.attn_mask_type: {ctx.attn_mask_type}, ctx.attn_bias_type: {ctx.attn_bias_type}") - print(f"is contiguous: {q_part.is_contiguous()}, {k_part.is_contiguous()}, {v_part.is_contiguous()}, {out_part.is_contiguous()}, {dout_part.is_contiguous()}") - print(fp8_meta_kwargs) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3887,9 +3849,6 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_bwd: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") - print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -3902,9 +3861,6 @@ def backward(ctx, dout, *_args): qkv_format=dqkv_format, cu_seqlens_padded=cu_seqlens_q_padded, ) - if torch.cuda.current_device() == 0: - print(f"after flash_attn_a2a_communicate: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") - print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") if dqkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif dqkv_format == "sbhd": @@ -3937,9 +3893,6 @@ def backward(ctx, dout, *_args): dv, src_nominal_dtype=bwd_nominal_dtype, ) - if torch.cuda.current_device() == 0: - print(f"after combine_and_dequantize: dq: {dq.shape}, dk: {dk.shape}, dv: {dv.shape}") - print(f"type of dq: {type(dq)}, type of dk: {type(dk)}, type of dv: {type(dv)}") nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( From 7fbe399c80e5fa177a1cfaceb3422286c9773289 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:58:41 -0800 Subject: [PATCH 36/59] test cp p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../dot_product_attention/context_parallel.py | 208 ++++++++++++------ .../attention/dot_product_attention/utils.py | 1 + 2 files changed, 138 insertions(+), 71 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 78c19826c8..dda856f36a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -794,13 +794,16 @@ def cp_p2p_fwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step, O_quantizer_per_step, rank, @@ -875,11 +878,15 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded_ = cu_seqlens_kv_padded fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -896,7 +903,8 @@ def cp_p2p_fwd_fused_attn( fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs, @@ -915,7 +923,7 @@ def cp_p2p_fwd_fused_attn( if return_max_logit: return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit - return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None #, new_qkv_layout def cp_p2p_fwd_flash_attn( @@ -1073,15 +1081,21 @@ def cp_p2p_bwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, + d_out_format, + dqkv_layout, attn_mask_type, attn_bias_type, deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, + # O_quantizer_per_step, + QKV_quantizer_per_step, + dO_quantizer_per_step, q_part, k_part, v_part, @@ -1131,16 +1145,26 @@ def cp_p2p_bwd_fused_attn( fp8_meta_kwargs = {} if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip( - [q_fp8, kv_fp8, kv_fp8], - [q_part, k_part, v_part], - ) - ] - if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): - out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step) + if not fp8_recipe.mxfp8(): + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + else: + # out_part, o_format = dpa_utils.permute_to_grouped_tensor(o_format, out_part) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + # out_part = O_quantizer_per_step(out_part) + aux_tensors.append(dout_part) + dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1156,7 +1180,7 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1164,6 +1188,9 @@ def cp_p2p_bwd_fused_attn( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, + d_out_format=d_out_format, + dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, @@ -1405,13 +1432,13 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - else: + elif not fp8_recipe.mxfp8(): # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers @@ -1432,10 +1459,11 @@ def forward( # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + S_quantizer_per_step[i] = S_quantizer.copy() if S_quantizer is not None else None O_quantizer_per_step[i] = O_quantizer.copy() - O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not fp8_recipe.mxfp8(): + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype @@ -1555,7 +1583,9 @@ def forward( # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None + o_format = qkv_format for i in range(cp_size + 1): + print(f">>>>>>>>>>>> {torch.cuda.current_device()}: i: {i}, cp_size: {cp_size}") if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): # wait until KV is received @@ -1608,13 +1638,16 @@ def forward( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step[i], O_quantizer_per_step[i], rank, @@ -1666,6 +1699,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1693,6 +1727,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1720,6 +1755,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1748,6 +1784,7 @@ def forward( rng_states[i], attn_biases[i], max_logit_per_step[i], + # qkv_layout, ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1775,7 +1812,7 @@ def forward( out_per_step[i - 1] = out_per_step[i - 1].dequantize( dtype=torch.float32 ) - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) if i == 1: @@ -1829,7 +1866,7 @@ def forward( # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: if i == 0: out = flash_attn_fwd_out_correction_init( out_per_step[0], @@ -1849,7 +1886,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1860,7 +1897,7 @@ def forward( softmax_lse_in_packed_format, ) else: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: flash_attn_fwd_second_half_out_correction( out, out_per_step[i], @@ -1868,7 +1905,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1879,10 +1916,10 @@ def forward( softmax_lse_in_packed_format, ) - if qkv_format == "bshd": + if o_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] - elif qkv_format == "sbhd": + elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] @@ -1892,10 +1929,10 @@ def forward( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) if use_fused_attention: - if qkv_format == "bshd": + if o_format == "bshd": # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) if return_max_logit: @@ -1906,7 +1943,7 @@ def forward( out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps - if fp8 and use_fused_attention: + if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) O_quantizer.amax.copy_(amax_cp_fwd[1]) @@ -1929,7 +1966,7 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) and not fp8_recipe.mxfp8()) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -1940,7 +1977,8 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - if fp8: + q_fp8, kv_fp8 = None, None + if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8], [q, kv]) @@ -1953,12 +1991,22 @@ def forward( fp8_tensors = (q_fp8, kv_fp8, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: f16_tensors = (None, None, out_f16) - elif fp8 and is_input_fp8: + elif fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) + elif fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): # fwd: fp8, bwd: f16, save all f16 # dequantize fp8 inputs q_f16 = q_fp8.dequantize() kv_f16 = kv_fp8.dequantize() f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and is_input_fp8 and fp8_recipe.mxfp8(): + # fwd: fp8, bwd: f16, save all f16 + # there is already an F16 version of the inputs + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) + kv = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv, out_f16) + elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) elif fp8: # fwd: fp8, bwd: f16, save all f16 # inputs are already in f16 @@ -1971,6 +2019,9 @@ def forward( q_f16 = q_f16.view(q.shape) kv_f16 = kv f16_tensors = (q_f16, kv_f16, out_f16) + if torch.cuda.current_device() == 0: + print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") + print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -2023,11 +2074,12 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") @@ -2045,7 +2097,7 @@ def backward(ctx, dout, *_args): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2086,6 +2138,7 @@ def backward(ctx, dout, *_args): causal = "causal" in ctx.attn_mask_type seq_dim = None qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + o_format = ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2142,28 +2195,33 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - bwd_output_te_dtype = None + # bwd_output_te_dtype = None dkv_buffer = None + d_out_format = o_format if ctx.fp8: assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" fused_attn_backend = FusedAttnBackend["FP8"] - q, kv, out = ( - q_fp8._data, - kv_fp8._data, - ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8._data - ), - ) + if not ctx.fp8_recipe.mxfp8(): + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - if ctx.is_output_fp8: + # if ctx.fp8_recipe.mxfp8(): + # dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout - else: + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) - dout = dout_fp8._data + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data # print quantizers print_quantizers( @@ -2178,7 +2236,7 @@ def backward(ctx, dout, *_args): ) # dout_fp8._fp8_dtype - bwd_output_te_dtype = ctx.dO_quantizer.dtype + # bwd_output_te_dtype = ctx.dO_quantizer.dtype # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): @@ -2187,7 +2245,7 @@ def backward(ctx, dout, *_args): dtype=buffer_dtype, device=q.device, ) - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_buffer = torch.empty( q.shape, dtype=torch.float32, @@ -2201,7 +2259,7 @@ def backward(ctx, dout, *_args): ) dkv_recv_buffer = torch.empty_like(dkv_send_buffer) p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dkv_buffer = torch.zeros( kv.shape, dtype=torch.float32, @@ -2214,10 +2272,11 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() - dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not ctx.fp8_recipe.mxfp8(): + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) @@ -2228,7 +2287,7 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] + # bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' @@ -2352,10 +2411,10 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8() else out_fp8 ), - dout_fp8, + dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, softmax_lse, softmax_lse_, rng_states, @@ -2373,15 +2432,21 @@ def backward(ctx, dout, *_args): ctx.softmax_scale, ctx.dropout_p, qkv_layout, + ctx.qkv_format, + ctx.qkv_format, + qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, + # bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], + # ctx.O_quantizer, + ctx.QKV_quantizer, + ctx.dO_quantizer, ] else: flash_attn_inputs = [ @@ -2455,7 +2520,7 @@ def backward(ctx, dout, *_args): if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8_recipe.delayed(): dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] # copy dq_ into the right buffer position @@ -2539,7 +2604,7 @@ def backward(ctx, dout, *_args): # dkv correction if ctx.fp8 and ctx.fp8_recipe.delayed(): dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] - elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + elif ctx.fp8 and (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()): dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] @@ -2629,9 +2694,10 @@ def backward(ctx, dout, *_args): # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + if not ctx.fp8_recipe.mxfp8(): + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) dq = dq_buffer if ctx.fp8_recipe.delayed(): @@ -2654,7 +2720,7 @@ def backward(ctx, dout, *_args): ) dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dk = dkv[: ctx.k_numel].view(ctx.k_shape) dv = dkv[ctx.k_numel :].view(ctx.v_shape) @@ -2670,7 +2736,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index eb2a7f8e94..05301c186d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2236,6 +2236,7 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] + print(f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}, s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}") assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 From aa05a2afa644505dbae63ea3bc7779f6ce948c30 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:34:45 -0800 Subject: [PATCH 37/59] minor fixes for mla Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 3 +- .../attention/test_attention_with_cp.py | 6 ++-- tests/pytorch/utils.py | 1 + .../dot_product_attention/context_parallel.py | 34 +++++++++++++------ .../attention/dot_product_attention/utils.py | 12 +++---- 5 files changed, 35 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index b019289846..a53d872302 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -465,7 +465,8 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[4] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): + print(f"========= {torch.cuda.current_device()}: tensors[{i}].shape: {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}") assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 668c2745c7..2ab64d2029 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -151,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 128),#, num_gqa_groups=12), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -288,8 +288,8 @@ def test_cp_with_fused_attention( pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") - if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently does not support FP8 attention!") + # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: + # pytest.skip("MLA CP currently does not support FP8 attention!") if dtype == "fp8" and config.softmax_type != "vanilla": pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c54295d478..ff8cb3e820 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -177,6 +177,7 @@ def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): rmse = torch.sqrt((a - b).square().mean()).item() logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + # rmse_tol = rmse_tol * 1.1 assert rmse < rmse_tol * rmse_range, ( name_a + " vs " diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index dda856f36a..864967d661 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1408,15 +1408,17 @@ def forward( q_fp8, k_fp8, v_fp8 = (None, None, None) # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: + print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}") if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = (q._data, k._data, v._data) + if not fp8_recipe.mxfp8(): + q, k, v = (q._data, k._data, v._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) - if fp8 and is_input_fp8: + if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) @@ -1576,6 +1578,7 @@ def forward( k_shape = k.shape k_numel = k.numel() v_shape = v.shape + o_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] @@ -1820,7 +1823,7 @@ def forward( if qkv_format == "thd": if enable_mla: out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape + o_shape ) else: # MHA or GQA @@ -1874,8 +1877,9 @@ def forward( softmax_lse_per_step[0], seq_dim, ) + print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}") if enable_mla: - out = out.view(v_shape) + out = out.view(o_shape) else: out = out.view(q.shape) else: @@ -1922,6 +1926,9 @@ def forward( elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] + print(f"========= {torch.cuda.current_device()}: out.shape: {out.shape} {out.dtype}") + out_part = out.to(fwd_nominal_dtype) + print(f"========= {torch.cuda.current_device()}: out_part.shape: {out_part.shape} {out_part.dtype}") if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -1986,6 +1993,7 @@ def forward( # q, kv, out fp8_tensors = (None, None, None) f16_tensors = (None, None, None) + out_f16 = out_part if ctx.fp8: # fwd: fp8, bwd: fp8, save all fp8 fp8_tensors = (q_fp8, kv_fp8, out_fp8) @@ -2064,6 +2072,7 @@ def forward( ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer @@ -2292,14 +2301,15 @@ def backward(ctx, dout, *_args): # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: + print(f"========= {torch.cuda.current_device()}: before a2a: out.shape: {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}") if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) - out, dout = flash_attn_a2a_communicate( - [out, dout], + dout = flash_attn_a2a_communicate( + [dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, @@ -2307,10 +2317,11 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) + print(f"========= {torch.cuda.current_device()}: after a2a: dout.shape: {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}") if ctx.enable_mla: - out = out.view(*ctx.v_shape) - dout = dout.view(*ctx.v_shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) else: # MHA or GQA out = out.view(*q.shape) @@ -2754,7 +2765,8 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if ctx.fp8 and ctx.is_input_fp8: dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv - dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) + if not ctx.fp8_recipe.mxfp8(): + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2765,7 +2777,7 @@ def backward(ctx, dout, *_args): ctx.cp_stream, False, ) - if ctx.fp8 and ctx.is_input_fp8: + if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 05301c186d..6f36aee355 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -835,12 +835,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with FP8" - " MLA attention" - ) - use_fused_attention = False + # elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: + # logger.debug( + # "Disabling FusedAttention as it does not support context parallelism with FP8" + # " MLA attention" + # ) + # use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends From 00e6693f978dbc877af63efb069429d840f786ed Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:51:46 -0800 Subject: [PATCH 38/59] open up a2a for mla Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention_with_cp.py | 2 +- .../pytorch/attention/dot_product_attention/context_parallel.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2ab64d2029..dc8c237d57 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -286,7 +286,7 @@ def test_cp_with_fused_attention( pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently does not support FP8 attention!") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 864967d661..32f908993e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4163,6 +4163,7 @@ def attn_forward_func_with_cp( assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", + "a2a", ], f"Context parallelism does not support MLA with {cp_comm_type=}!" if fp8 and fp8_meta is not None: From b8d28ceb6b92a811385c0d377756b9fa6d19c750 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:33:44 -0800 Subject: [PATCH 39/59] test ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/test_attention_with_cp.py | 10 +- .../dot_product_attention/context_parallel.py | 244 +++++++++++++++--- 2 files changed, 213 insertions(+), 41 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index dc8c237d57..efed75925a 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -151,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 128), #192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -249,10 +249,10 @@ def test_cp_with_fused_attention( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" " yet!" ) - if dtype == "fp8" and cp_comm_type == "all_gather": - pytest.skip( - "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" - ) + # if dtype == "fp8" and cp_comm_type == "all_gather": + # pytest.skip( + # "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" + # ) if dtype == "fp8" and qkv_format == "thd": pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 32f908993e..88c98ce041 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2883,6 +2883,10 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -2892,7 +2896,11 @@ def forward( cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) - qkv_dtype = q.dtype + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + q_shape = q.shape + k_shape = k.shape + v_shape = v.shape causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -2936,9 +2944,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - seq_dim = qkv_format.index("s") assert ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 @@ -2953,6 +2958,42 @@ def forward( else: cu_seqlens_q_padded = None + # FP8 setup + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + ( + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + fwd_nominal_dtype = q.dtype + fp8_meta_kwargs = {} + q_fp8, k_fp8, v_fp8 = (None, None, None) + fused_attn_backend = None + if fp8: + assert use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + else: + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer + elif use_fused_attention: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] @@ -2983,7 +3024,9 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) + enable_mla = k.shape[-1] != v.shape[-1] + out_shape = q.shape if not enable_mla else q.shape[:-1] + v.shape[-1:] + out = torch.empty(out_shape, dtype=fwd_nominal_dtype, device=q.device) max_logit_per_step = [None, None] max_logit = None @@ -3016,6 +3059,14 @@ def forward( # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: + q_part, k_part, v_part = q_, k_, v_ + if fp8: + if not fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like(q_fp8, data=q_, dtype=fwd_nominal_dtype) + k_part = Float8Tensor.make_like(k_fp8, data=k_, dtype=fwd_nominal_dtype) + v_part = Float8Tensor.make_like(v_fp8, data=v_, dtype=fwd_nominal_dtype) + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) ( out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], @@ -3026,14 +3077,15 @@ def forward( max_seqlen_kv_, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - qkv_dtype, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + q_part, + k_part, + v_part, + fwd_nominal_dtype, + fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3042,9 +3094,12 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) if return_max_logit: max_logit_per_step[i] = max_logit_[0] + if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): + out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: fa_forward_args_thd = get_fa_args( True, @@ -3104,10 +3159,38 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) - ctx.save_for_backward( - q, - k, - v, + out_fp8 = None + out_ret = out + if fp8 and (is_output_fp8 or (is_bwd_fp8 and fp8_recipe.delayed())): + out_fp8 = O_quantizer(out) + out_ret = out_fp8 + ctx.fp8 = fp8 and is_bwd_fp8 + ctx.fp8_recipe = fp8_recipe + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + if ctx.fp8: + if fp8_recipe.delayed(): + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + f16_tensors = (None, None, None, out) + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out) + elif fp8: + if is_input_fp8: + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + f16_tensors = (q, k, v, out) + else: + f16_tensors = (q, k, v, out) + if torch.cuda.current_device() == 0: + print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") + print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") + + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, @@ -3115,8 +3198,14 @@ def forward( *softmax_lse_per_step, *rng_states, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.q_shape = q_shape + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.out_shape = out_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group @@ -3130,10 +3219,24 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + if fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.O_quantizer = O_quantizer.copy() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + ctx.dQKV_quantizer = dQKV_quantizer.copy() + ctx.dO_quantizer = dO_quantizer.copy() + ctx.dP_quantizer = dP_quantizer.copy() if dP_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: - return out, max_logit - return out + return out_ret, max_logit + return out_ret @staticmethod def backward(ctx, dout, *_args): @@ -3142,22 +3245,41 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] - cu_seqlens_kv_per_step = saved_tensors[5:7] - out_per_step = saved_tensors[7:9] - softmax_lse_per_step = saved_tensors[9:11] - rng_states = saved_tensors[11:13] + ( + q_fp8, k_fp8, v_fp8, out_fp8, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_q_padded, + cu_seqlens_kv_per_step, + out_per_step, + softmax_lse_per_step, + rng_states + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - dout = dout.view(q.shape) - dq = torch.empty_like(q) - dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) - dv = torch.zeros_like(dk) + dout = dout.view(ctx.out_shape) + dout_fp8 = None + if ctx.fp8: + if ( + ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): + dout_fp8 = ctx.dO_quantizer(dout) + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data + + if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): + q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + + dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dk = torch.zeros((ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=k.device) + dv = torch.zeros((ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=v.device) dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3233,31 +3355,65 @@ def backward(ctx, dout, *_args): dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8_meta_kwargs = {} + if ctx.fp8: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + if not ctx.fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like( + q_fp8, data=q_, dtype=ctx.fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_, dtype=ctx.fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_, dtype=ctx.fwd_nominal_dtype + ) + if not (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = ctx.O_quantizer(out_part) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype) + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) + aux_ctx_tensors.append(dout_part) + dout_part = ctx.dO_quantizer(dout_part) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - out_, - dout_, - ctx.qkv_dtype, - TE_DType[dout.dtype], + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.fwd_nominal_dtype, + # TE_DType[dout.dtype], aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, + o_format=ctx.qkv_format, + d_out_format=ctx.qkv_format, + dqkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if ctx.fp8: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + x.dequantize(dtype=ctx.fwd_nominal_dtype) if isinstance(x, QuantizedTensorStorage) else x + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] @@ -3335,6 +3491,10 @@ def backward(ctx, dout, *_args): dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) dk = dk.movedim(0, seq_dim).contiguous() dv = dv.movedim(0, seq_dim).contiguous() + + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3359,6 +3519,9 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, ) @@ -4222,7 +4385,16 @@ def attn_forward_func_with_cp( elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [ + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, + ] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ From d6ecadc12c2192bc443167f7efad3e201c77e763 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Feb 2026 20:32:22 -0800 Subject: [PATCH 40/59] tweaks for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 2 +- .../dot_product_attention/context_parallel.py | 58 ++++++++++++++----- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index da826688be..764e95c330 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2891,7 +2891,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); - if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 88c98ce041..9dea649633 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2898,9 +2898,6 @@ def forward( assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - q_shape = q.shape - k_shape = k.shape - v_shape = v.shape causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -2985,7 +2982,7 @@ def forward( fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v - else: + elif not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -2996,8 +2993,11 @@ def forward( # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + q_shape = q.shape # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + k_shape = k.shape + v_shape = v.shape # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) @@ -3060,16 +3060,17 @@ def forward( k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: q_part, k_part, v_part = q_, k_, v_ + new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): q_part = Float8Tensor.make_like(q_fp8, data=q_, dtype=fwd_nominal_dtype) k_part = Float8Tensor.make_like(k_fp8, data=k_, dtype=fwd_nominal_dtype) v_part = Float8Tensor.make_like(v_fp8, data=v_, dtype=fwd_nominal_dtype) else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) ( out_per_step[i], - [softmax_lse_per_step[i], rng_states[i]], + aux_ctx_tensors, *max_logit_, ) = fused_attn_fwd( is_training, @@ -3084,7 +3085,7 @@ def forward( fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, o_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, @@ -3096,6 +3097,10 @@ def forward( cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): @@ -3245,15 +3250,23 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) + cu_seqlens_kv_per_step = [None, None] + out_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] ( q_fp8, k_fp8, v_fp8, out_fp8, q, k, v, out, cu_seqlens_q, cu_seqlens_q_padded, - cu_seqlens_kv_per_step, - out_per_step, - softmax_lse_per_step, - rng_states + cu_seqlens_kv_per_step[0], + cu_seqlens_kv_per_step[1], + out_per_step[0], + out_per_step[1], + softmax_lse_per_step[0], + softmax_lse_per_step[1], + rng_states[0], + rng_states[1], ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) kv_seq_range_per_step = ctx.kv_seq_range_per_step @@ -3277,9 +3290,15 @@ def backward(ctx, dout, *_args): if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if torch.cuda.current_device() == 0: + print(f"ctx.q_shape: {ctx.q_shape} {ctx.k_shape} {ctx.v_shape}") dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros((ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=k.device) dv = torch.zeros((ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=v.device) + if torch.cuda.current_device() == 0: + print(f"dq: {dq.shape} {dq.dtype} {dq.device}") + print(f"dk: {dk.shape} {dk.dtype} {dk.device}") + print(f"dv: {dv.shape} {dv.dtype} {dv.device}") dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3354,10 +3373,12 @@ def backward(ctx, dout, *_args): out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + aux_ctx_tensors = [softmax_lse_per_step[i], softmax_lse_per_step[i], rng_states[i]] q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout + d_out_format = ctx.qkv_format if ctx.fp8: fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3377,9 +3398,13 @@ def backward(ctx, dout, *_args): out_part = ctx.O_quantizer(out_part) dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype) else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) + print(f"aux_ctx_tensors: {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}") + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) + print(f"q_part type: {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}") + print(f"q_part shape: {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}") dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3398,9 +3423,9 @@ def backward(ctx, dout, *_args): cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, o_format=ctx.qkv_format, - d_out_format=ctx.qkv_format, + d_out_format=d_out_format, dqkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, @@ -3453,6 +3478,8 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): + if torch.cuda.current_device() == 0: + print(f"dq.shape: {dq.shape} dq_per_step[i - 1].shape: {dq_per_step[i - 1].shape}") if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -3522,6 +3549,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, ) From 3ac48cd095799f79bd06fbe126edd3237bef267a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 13:51:36 -0800 Subject: [PATCH 41/59] enable mla ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention_with_cp.py | 6 +++--- .../dot_product_attention/context_parallel.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index efed75925a..1ac9dc7398 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -151,7 +151,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 128), #192, head_dim_v=128, num_gqa_groups=4, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128), #num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -286,8 +286,8 @@ def test_cp_with_fused_attention( pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently only support KV P2P!") + # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: + # pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently does not support FP8 attention!") if dtype == "fp8" and config.softmax_type != "vanilla": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 9dea649633..b8a50e8b77 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4351,11 +4351,11 @@ def attn_forward_func_with_cp( ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] - assert not enable_mla or cp_comm_type in [ - "p2p", - "a2a+p2p", - "a2a", - ], f"Context parallelism does not support MLA with {cp_comm_type=}!" + # assert not enable_mla or cp_comm_type in [ + # "p2p", + # "a2a+p2p", + # "a2a", + # ], f"Context parallelism does not support MLA with {cp_comm_type=}!" if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: From 5d4fa5e2038bd5e21747ab0bc69dbff93fdab847 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:36:49 +0000 Subject: [PATCH 42/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 11 +- tests/pytorch/attention/test_attention.py | 6 +- .../attention/test_attention_with_cp.py | 4 +- tests/pytorch/test_grouped_tensor.py | 3 +- .../common/fused_attn/fused_attn.cpp | 88 +-- .../common/fused_attn/fused_attn_fp8.cu | 510 ++++++++++-------- .../common/fused_attn/fused_attn_fp8.h | 41 +- transformer_engine/common/fused_attn/utils.cu | 2 +- transformer_engine/common/fused_attn/utils.h | 468 ++++++++-------- .../include/transformer_engine/fused_attn.h | 39 +- .../transformer_engine/transformer_engine.h | 385 ++++++------- transformer_engine/common/recipe/__init__.py | 1 + .../common/transformer_engine.cpp | 212 ++++---- .../dot_product_attention/backends.py | 53 +- .../dot_product_attention/context_parallel.py | 202 +++++-- .../attention/dot_product_attention/utils.py | 17 +- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 6 +- .../pytorch/csrc/extensions/attention.cpp | 51 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- .../pytorch/tensor/storage/grouped_tensor.py | 2 +- 21 files changed, 1197 insertions(+), 914 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index c9d6d9d64f..5cb43f277a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -21,7 +21,12 @@ Float8CurrentScalingQuantizer, MXFP8Quantizer, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Format +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Format, +) from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -250,7 +255,9 @@ def run_dpa_with_cp( if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) # instantiate attention module core_attn = DotProductAttention( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5760ca2434..47abf1ebc6 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,8 +1804,10 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),#, attn_mask_type="causal"), - "fp8_10": ModelConfig(2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512)), + "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), + "fp8_10": ModelConfig( + 2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512) + ), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 9e166fa908..a5fe8f74f5 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -154,7 +154,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 16, 192, head_dim_v=128), #num_gqa_groups=4, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig( + 2, 4096, 16, 192, head_dim_v=128 + ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index ab9ec28984..31d84933de 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -361,7 +361,6 @@ def test_static_quantize_method(self, quantization: str) -> None: expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset - @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) def test_quantize_grouped_mxfp8(self) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" @@ -372,7 +371,7 @@ def test_quantize_grouped_mxfp8(self) -> None: # Create BF16 input tensors and pack into a grouped tensor input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes] quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - quantizer.optimize_for_gemm=True + quantizer.optimize_for_gemm = True grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shapes=shapes, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1f5db127a0..72c5273a78 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -141,9 +141,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } // map one NVTE_QKV_Format to another -void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, - size_t *b, size_t *h, size_t *s, size_t *d, size_t *t) { - size_t _b=0, _h=0, _s=0, _d=0, _t=0; +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t) { + size_t _b = 0, _h = 0, _s = 0, _d = 0, _t = 0; switch (src_format) { case NVTE_QKV_Format::NVTE_BSHD: _b = src_shape[0]; @@ -270,8 +271,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: mxfp8, d_qk=128, d_v=192 - (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && + // 9.21: mxfp8, d_qk=128, d_v=192 + (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -411,13 +412,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || q_format == NVTE_QKV_Format::NVTE_BHSD || + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || kv_format == NVTE_QKV_Format::NVTE_BHSD || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -428,7 +431,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && (window_size_right == -1 || window_size_right >= 0) && + ((window_size_left == -1 || window_size_left >= 0) && + (window_size_right == -1 || window_size_right >= 0) && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && @@ -537,19 +541,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -583,8 +585,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } - size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; - size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] + : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv - 3] + : input_K->data.shape[ndim_kv - 2]; int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -648,9 +652,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, - dropout, qkv_layout, o_format, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, + attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, + input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); @@ -668,11 +673,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -708,8 +714,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } - size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim-3] : input_Q->data.shape[ndim - 2]; - size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv-3] : input_K->data.shape[ndim_kv - 2]; + size_t h_q = (q_format == NVTE_QKV_Format::NVTE_BHSD) ? input_Q->data.shape[ndim - 3] + : input_Q->data.shape[ndim - 2]; + size_t h_kv = (kv_format == NVTE_QKV_Format::NVTE_BHSD) ? input_K->data.shape[ndim_kv - 3] + : input_K->data.shape[ndim_kv - 2]; auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -762,13 +770,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); const Tensor *input_dO_f16; if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { - input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[3]); } fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, - input_dO, input_dO_f16, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, + input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 003ca0051d..237f3bd66e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,13 +1652,15 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, + void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1676,14 +1678,18 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, - "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); try { FADescriptor_v1 descriptor{b, @@ -1772,40 +1778,40 @@ void fused_attn_fp8_fwd_impl_v1( std::vector k_stride(4); std::vector v_stride(4); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_V_Matrix); Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_data_type(qkv_tensor_type)); + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_data_type(qkv_tensor_type)); + .set_name("K") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_data_type(qkv_tensor_type)); + .set_name("V") + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Scale_o if (is_delayed_scaling || is_current_scaling) { descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); @@ -1824,27 +1830,33 @@ void fused_attn_fp8_fwd_impl_v1( std::vector k_scale_strides(4); std::vector v_scale_strides(4); auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, v_scale_strides.data(), kv_format, false); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) - .set_stride(q_scale_strides) - .set_data_type(fe::DataType_t::FP8_E8M0) - .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_k") - .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) - .set_stride(k_scale_strides) - .set_data_type(fe::DataType_t::FP8_E8M0) - .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_v") - .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) - .set_stride(v_scale_strides) - .set_data_type(fe::DataType_t::FP8_E8M0) - .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, + v_scale_strides.data(), kv_format, false); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_attributes sdpa_options; @@ -1854,7 +1866,7 @@ void fused_attn_fp8_fwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - fe::DiagonalAlignment_t const &diagonal_alignment = + fe::DiagonalAlignment_t const& diagonal_alignment = bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT : fe::DiagonalAlignment_t::TOP_LEFT; sdpa_options.set_diagonal_alignment(diagonal_alignment); @@ -1910,21 +1922,24 @@ void fused_attn_fp8_fwd_impl_v1( Stats = outputs[1]; amax_o = outputs[2]; } else { - auto outputs = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, sdpa_options); O = outputs[0]; Stats = outputs[1]; amax_s = outputs[2]; amax_o = outputs[3]; amax_s->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); } std::vector o_stride(4); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); - O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride).set_data_type(o_tensor_type); + O->set_output(true) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -1949,9 +1964,10 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // amax_s std::shared_ptr> // amax_o key_tensors_tuple = - is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, nullptr, attn_scale, O, nullptr, amax_o) : - std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o); + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, + nullptr, attn_scale, O, nullptr, amax_o) + : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = @@ -2040,20 +2056,23 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, - void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, - void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, - void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, - void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, - void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, + void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, + void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, + void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, + void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, + void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2070,14 +2089,18 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || - dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - NVTE_CHECK(is_delayed_scaling || is_current_scaling || is_mxfp8, - "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2179,8 +2202,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, attn_scale; - std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, + attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, + descale_v; std::shared_ptr descale_s, descale_o; std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; @@ -2194,11 +2219,11 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_stride(4); std::vector o_stride(4); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, q_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, k_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, v_stride.data(), qkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_V_Matrix); generateMatrixStridesWithFormat(b, h, s_q, d_v, o_stride.data(), o_format, false); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -2277,17 +2302,28 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); - printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], q_t_stride[3]); - printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], k_t_stride[3]); - printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], dO_t_stride[3]); - printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, cudnn_frontend::DataType_t::BFLOAT16); - printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); - printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, NVTE_QKV_Format::NVTE_BHSD); - printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], + q_t_stride[3]); + printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], + k_t_stride[3]); + printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], + dO_t_stride[3]); + printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, + cudnn_frontend::DataType_t::BFLOAT16); + printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, + cudnn_frontend::DataType_t::FP8_E5M2); + printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); printf("b: %d\n", b); printf("h: %d\n", h); printf("hg: %d\n", hg); @@ -2304,25 +2340,25 @@ void fused_attn_fp8_bwd_impl_v1( printf("is_dropout: %d\n", is_dropout); printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q_t") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_t_stride) - .set_data_type(qkv_tensor_type)); + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_stride) + .set_data_type(qkv_tensor_type)); K_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K_t") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_t_stride) - .set_data_type(qkv_tensor_type)); + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_stride) + .set_data_type(qkv_tensor_type)); dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_t") - .set_dim({b, h, s_q, d_v}) - .set_stride(dO_t_stride) - .set_data_type(do_tensor_type)); + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_stride) + .set_data_type(do_tensor_type)); dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO_f16") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_data_type(o_tensor_type)); + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride) + .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); printf("s_q_padded: %d\n", padded.s_q_padded); @@ -2344,57 +2380,78 @@ void fused_attn_fp8_bwd_impl_v1( std::vector v_scale_strides(4); std::vector dO_scale_strides(4); std::vector dO_t_scale_strides(4); - generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, q_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, q_t_scale_strides.data(), q_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, k_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, k_t_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, v_scale_strides.data(), kv_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, dO_scale_strides.data(), d_out_format, false); - generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); - printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], q_scale_strides[2], q_scale_strides[3]); - printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], q_t_scale_strides[2], q_t_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], k_scale_strides[2], k_scale_strides[3]); - printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], k_t_scale_strides[2], k_t_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], v_scale_strides[2], v_scale_strides[3]); - printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], dO_scale_strides[2], dO_scale_strides[3]); - printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], dO_t_scale_strides[2], dO_t_scale_strides[3]); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, + q_t_scale_strides.data(), q_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, + k_t_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, + v_scale_strides.data(), kv_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, + dO_scale_strides.data(), d_out_format, false); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, + dO_t_scale_strides.data(), d_out_format, false); + printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], + q_scale_strides[2], q_scale_strides[3]); + printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], + q_t_scale_strides[2], q_t_scale_strides[3]); + printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], + k_scale_strides[2], k_scale_strides[3]); + printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], + k_t_scale_strides[2], k_t_scale_strides[3]); + printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], + v_scale_strides[2], v_scale_strides[3]); + printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], + dO_scale_strides[2], dO_scale_strides[3]); + printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], + dO_t_scale_strides[2], dO_t_scale_strides[3]); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) .set_stride(q_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_q_t = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q_t") .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) .set_stride(q_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_k = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k") .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) .set_stride(k_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_k_t = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_k_t = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_k_t") .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) .set_stride(k_t_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_v = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_v") .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) .set_stride(v_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_dO = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_dO = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO") .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) .set_stride(dO_scale_strides) .set_data_type(fe::DataType_t::FP8_E8M0) .set_reordering_type(fe::TensorReordering_t::F8_128x4)); - descale_dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + descale_dO_t = + mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_dO_t") .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) .set_stride(dO_t_scale_strides) @@ -2408,7 +2465,7 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - fe::DiagonalAlignment_t const &diagonal_alignment = + fe::DiagonalAlignment_t const& diagonal_alignment = bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT : fe::DiagonalAlignment_t::TOP_LEFT; sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); @@ -2474,9 +2531,10 @@ void fused_attn_fp8_bwd_impl_v1( } std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; if (is_delayed_scaling || is_current_scaling) { - auto outputs = mha_graph->sdpa_fp8_backward( - Q, K, V, O, dO, Stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + auto outputs = mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, sdpa_backward_options); dQ = outputs[0]; dK = outputs[1]; dV = outputs[2]; @@ -2487,8 +2545,8 @@ void fused_attn_fp8_bwd_impl_v1( } if (is_mxfp8) { auto outputs = mha_graph->sdpa_fp8_backward( - Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, descale_k_t, - descale_v, descale_dO, descale_dO_t, sdpa_backward_options); + Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, + descale_k_t, descale_v, descale_dO, descale_dO_t, sdpa_backward_options); dQ = outputs[0]; dK = outputs[1]; dV = outputs[2]; @@ -2500,17 +2558,26 @@ void fused_attn_fp8_bwd_impl_v1( std::vector dk_stride(4); std::vector dv_stride(4); generateMatrixStrides_v1(b, h, hg, s_q, s_kv, d_qk, d_v, dq_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); + NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dk_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); + NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + NVTE_QKV_Matrix::NVTE_V_Matrix); printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); - dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(dq_stride).set_data_type(dqkv_tensor_type); - dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(dk_stride).set_data_type(dqkv_tensor_type); - dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(dv_stride).set_data_type(dqkv_tensor_type); + dQ->set_output(true) + .set_dim({b, h, s_q, d_qk}) + .set_stride(dq_stride) + .set_data_type(dqkv_tensor_type); + dK->set_output(true) + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(dk_stride) + .set_data_type(dqkv_tensor_type); + dV->set_output(true) + .set_dim({b, hg, s_kv, d_v}) + .set_stride(dv_stride) + .set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2524,10 +2591,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); if (!is_mxfp8) { - amax_dP->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); } std::tuple, // Q @@ -2560,8 +2627,9 @@ void fused_attn_fp8_bwd_impl_v1( Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); - auto mxfp8_tensors_tuple = is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + auto mxfp8_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -2574,16 +2642,18 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, + bias_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, - bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, + descale_k_t, descale_dO_t, bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = + get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2650,23 +2720,34 @@ void fused_attn_fp8_bwd_impl_v1( printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); - printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, is_aligned_modulo(devPtrDescaleQ, modulo)); - printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, is_aligned_modulo(devPtrDescaleK, modulo)); - printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, is_aligned_modulo(devPtrDescaleV, modulo)); - printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, is_aligned_modulo(devPtrDescaledO, modulo)); - printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, is_aligned_modulo(devPtrDescaledO_t, modulo)); + printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, + is_aligned_modulo(devPtrDescaleQ, modulo)); + printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, + is_aligned_modulo(devPtrDescaleK, modulo)); + printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, + is_aligned_modulo(devPtrDescaleV, modulo)); + printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, + is_aligned_modulo(devPtrDescaledO, modulo)); + printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, + is_aligned_modulo(devPtrDescaledO_t, modulo)); printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); - printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, is_aligned_modulo(devPtrAmaxdQ, modulo)); - printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, is_aligned_modulo(devPtrAmaxdK, modulo)); - printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, is_aligned_modulo(devPtrAmaxdV, modulo)); + printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, + is_aligned_modulo(devPtrAmaxdQ, modulo)); + printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, + is_aligned_modulo(devPtrAmaxdK, modulo)); + printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, + is_aligned_modulo(devPtrAmaxdV, modulo)); printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); - printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, is_aligned_modulo(devPtrdO_f16, modulo)); + printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, + is_aligned_modulo(devPtrdO_f16, modulo)); printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); - printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, is_aligned_modulo(devPtrDescaleQ_t, modulo)); - printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, is_aligned_modulo(devPtrDescaleK_t, modulo)); + printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, + is_aligned_modulo(devPtrDescaleQ_t, modulo)); + printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, + is_aligned_modulo(devPtrDescaleK_t, modulo)); /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -2709,14 +2790,16 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, + Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = nullptr; void* devPtrK = nullptr; @@ -2753,7 +2836,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleO = output_O->scale.dptr; devPtrAmaxS = input_output_S->amax.dptr; devPtrScaleS = input_output_S->scale.dptr; - devPtrDescaleS = input_output_S->scale_inv.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; } void* devPtrM = nullptr; void* devPtrZInv = nullptr; @@ -2796,13 +2879,15 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, - attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, + devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { @@ -2831,16 +2916,19 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, - const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, - const Tensor* input_dO, const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, - const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, - const Tensor* output_dK, const Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, bool deterministic, + const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, + const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, + Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, + const Tensor* output_dV, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2899,19 +2987,21 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); - if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, - p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, - devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, + devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, + devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, + devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), + get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, workspace->data.dptr, + &workspace_size, stream, handle); } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 9683974a26..98d5876ec8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,26 +15,31 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, - const Tensor *output_dK, const Tensor *output_dV, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +// fused attention BWD FP8 with separate Q, K, V +void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, bool deterministic, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, + const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, + Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, + const Tensor *output_dV, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 8a9399e830..e67ae5e206 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -319,7 +319,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_transpose_dim_idx] = 1; } break; -} + } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { strideA[seqlen_kv_dim_idx] = 1; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 2c03245560..3e4ca696e2 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -49,11 +49,11 @@ struct MXFP8PaddedSizes { int64_t d_v_scale_padded; }; -inline bool is_aligned_modulo(void* ptr, int64_t modulo) { - // Cast the pointer to a large enough integer type (uintptr_t) - uintptr_t address = reinterpret_cast(ptr); - // Check if the address is perfectly divisible by 16 - return (address % modulo) == 0; +inline bool is_aligned_modulo(void *ptr, int64_t modulo) { + // Cast the pointer to a large enough integer type (uintptr_t) + uintptr_t address = reinterpret_cast(ptr); + // Check if the address is perfectly divisible by 16 + return (address % modulo) == 0; } // Pad s and d for MXFP8 layout @@ -78,7 +78,8 @@ inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_q // Get matrix strides for a 4D tensor [batch, head, seqlen, hidden] given a QKV format. // strideA must point to at least 4 int64_t elements. inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, - int64_t *strides, NVTE_QKV_Format format, bool transpose) { + int64_t *strides, NVTE_QKV_Format format, + bool transpose) { constexpr int batch_dim_idx = 0; constexpr int head_dim_idx = 1; int seqlen_dim_idx = transpose ? 3 : 2; @@ -111,234 +112,233 @@ inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int } // get matrix strides based on matrix type -inline void generateMatrixStrides_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t *strides, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) -{ - constexpr int batch_dim_idx = 0; - constexpr int head_dim_idx = 1; - bool transpose = - (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); - int seqlen_dim_idx = transpose ? 3 : 2; - int hidden_dim_idx = transpose ? 2 : 3; - constexpr int seqlen_q_dim_idx = 2; - constexpr int seqlen_kv_dim_idx = 3; - - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { - matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; - } - NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); - - switch (layout) { - case NVTE_QKV_Layout::NVTE_SB3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 3 * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 3 * h * d_qk; - strides[head_dim_idx] = 3 * d_qk; - strides[seqlen_dim_idx] = b * 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 2 * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = 2 * hg * d_qk; - strides[head_dim_idx] = 2 * d_qk; - strides[seqlen_dim_idx] = b * 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = b * hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BS3HD: - case NVTE_QKV_Layout::NVTE_T3HD: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_q * 3 * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSH3D: - case NVTE_QKV_Layout::NVTE_TH3D: - if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_q * 3 * h * d_qk; - strides[head_dim_idx] = 3 * d_qk; - strides[seqlen_dim_idx] = 3 * h * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: - case NVTE_QKV_Layout::NVTE_THD_T2HD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: - case NVTE_QKV_Layout::NVTE_THD_TH2D: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || - (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { - strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; - strides[head_dim_idx] = 2 * d_qk; - strides[seqlen_dim_idx] = 2 * hg * d_qk; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_THD_THD_THD: - case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = s_kv * hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = s_q * h * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = h * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * d_qk; - strides[head_dim_idx] = d_qk; - strides[seqlen_dim_idx] = b * hg * d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * d_v; - strides[head_dim_idx] = d_v; - strides[seqlen_dim_idx] = b * hg * d_v; - strides[hidden_dim_idx] = 1; - } - break; - case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: - if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { - strides[batch_dim_idx] = h * s_q * d_qk; - strides[head_dim_idx] = s_q * d_qk; - strides[seqlen_dim_idx] = d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { - strides[batch_dim_idx] = hg * s_kv * d_qk; - strides[head_dim_idx] = s_kv * d_qk; - strides[seqlen_dim_idx] = d_qk; - strides[hidden_dim_idx] = 1; - } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { - strides[batch_dim_idx] = hg * s_kv * d_v; - strides[head_dim_idx] = s_kv * d_v; - strides[seqlen_dim_idx] = d_v; - strides[hidden_dim_idx] = 1; - } - break; - default: - NVTE_CHECK(false, "Invalid layout."); - break; - } +inline void generateMatrixStrides_v1(int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, + int64_t d_qk, int64_t d_v, int64_t *strides, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + bool transpose = (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); + int seqlen_dim_idx = transpose ? 3 : 2; + int hidden_dim_idx = transpose ? 2 : 3; + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_Q_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_K_Matrix; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose) { + matrix = NVTE_QKV_Matrix::NVTE_V_Matrix; + } + NVTE_CHECK(matrix != NVTE_QKV_Matrix::NVTE_O_Matrix, + "Invalid matrix type. Expected Q, K, V, O, or their related transposes."); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = b * 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = b * 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_q * 3 * h * d_qk; + strides[head_dim_idx] = 3 * d_qk; + strides[seqlen_dim_idx] = 3 * h * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strides[batch_dim_idx] = s_kv * 2 * hg * d_qk; + strides[head_dim_idx] = 2 * d_qk; + strides[seqlen_dim_idx] = 2 * hg * d_qk; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = s_kv * hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = s_q * h * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = h * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * d_qk; + strides[head_dim_idx] = d_qk; + strides[seqlen_dim_idx] = b * hg * d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * d_v; + strides[head_dim_idx] = d_v; + strides[seqlen_dim_idx] = b * hg * d_v; + strides[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if (matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) { + strides[batch_dim_idx] = h * s_q * d_qk; + strides[head_dim_idx] = s_q * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_qk; + strides[head_dim_idx] = s_kv * d_qk; + strides[seqlen_dim_idx] = d_qk; + strides[hidden_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) { + strides[batch_dim_idx] = hg * s_kv * d_v; + strides[head_dim_idx] = s_kv * d_v; + strides[seqlen_dim_idx] = d_v; + strides[hidden_dim_idx] = 1; + } + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { strides[seqlen_kv_dim_idx] = 1; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 6ee4d1a8ba..90393ce8c8 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -207,7 +207,9 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * * \return The destination shape. */ - void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, size_t *h, size_t *s, size_t *d, size_t *t); +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t); /*! \brief Get fused attention backend based on input parameters. * @@ -305,19 +307,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -391,11 +391,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 68fe616a93..a6cb036a35 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -450,7 +450,8 @@ enum NVTEGroupedTensorParam { kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ <<<<<<< HEAD - kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ + kNVTEGroupedWithGEMMSwizzledScales = + 10, /*!< Whether scaling factors are in format expected by GEMM */ ======= kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ @@ -517,8 +518,8 @@ void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTe * \param[in] param The value to be set (NVTEBasicTensor). */ void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, - const void *buf, size_t size_in_bytes); - + const void *buf, size_t size_in_bytes); + /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get a value of the parameter of the grouped tensor. * @@ -527,7 +528,9 @@ void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTenso * * \return NVTEBasicTensor containing the parameter data. */ -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, void *buf, size_t size_in_bytes, size_t *size_written); +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, + NVTEGroupedTensorParam param_name, void *buf, + size_t size_in_bytes, size_t *size_written); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. @@ -992,9 +995,9 @@ class TensorWrapper { */ <<<<<<< HEAD - class GroupedTensorWrapper { - public: - /*! \brief Constructs new GroupedTensorWrapper. +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. * * Create a new TE grouped tensor with a given logical shape. * TE grouped tensors are just wrappers on top of raw data and do not @@ -1004,11 +1007,11 @@ class TensorWrapper { * \param[in] logical_shape Logical 2D shape of the grouped data. * \param[in] scaling_mode Tensor data format. */ - GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} - - /*! \brief Constructs new GroupedTensorWrapper. + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. * * Create a new TE grouped tensor with a given logical shape. * @@ -1016,194 +1019,196 @@ class TensorWrapper { * \param[in] logical_shape Logical 2D shape of the grouped data. * \param[in] scaling_mode Tensor data format. */ - GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : GroupedTensorWrapper(num_tensors, - nvte_make_shape(logical_shape.data(), logical_shape.size()), - scaling_mode) {} - - /*! \brief GroupedTensorWrapper destructor. */ - ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } - - GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; - GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; - - /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ - GroupedTensorWrapper(GroupedTensorWrapper &&other) { - tensor_ = other.tensor_; - other.tensor_ = nullptr; - } - - /*! \brief Assign the data from existing GroupedTensorWrapper. */ - GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { - if (this == &other) return *this; - nvte_destroy_grouped_tensor(tensor_); - tensor_ = other.tensor_; - other.tensor_ = nullptr; - return *this; - } - - // Parameter setters - template - GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, - const ShapeType &shape) noexcept { - NVTEShape nvte_shape = this->convertShape(shape); - NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; - nvte_set_grouped_tensor_param(&tensor_, param, &data); - return *this; - } - - template - GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedScale, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. */ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); - } + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } - void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { const auto val = static_cast(with_gemm_swizzled_scales); - nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val)); - } - - // Parameter getters - NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { - return nvte_get_grouped_tensor_param(tensor_, param); - } - - NVTEBasicTensor get_rowwise_data() const noexcept { - return get_parameter(kNVTEGroupedRowwiseData); - } - - NVTEBasicTensor get_columnwise_data() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseData); - } - - NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } - - NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } - - NVTEBasicTensor get_rowwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedRowwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_amax() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseAmax); - } - - NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } - - NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } - - NVTEBasicTensor get_tensor_offsets() const noexcept { - return get_parameter(kNVTEGroupedTensorOffsets); - } - - bool get_with_gemm_swizzled_scales() const { + nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, + sizeof(val)); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + bool get_with_gemm_swizzled_scales() const { uint8_t val = 0; - nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), nullptr); + nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), + nullptr); return static_cast(val); } - /*! \brief Get an underlying NVTEGroupedTensor. + /*! \brief Get an underlying NVTEGroupedTensor. * * \return NVTEGroupedTensor held by this GroupedTensorWrapper. */ - NVTEGroupedTensor data() const noexcept { return tensor_; } - - /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ - size_t num_tensors() const noexcept { - if (tensor_ == nullptr) return 0; - return nvte_grouped_tensor_num_tensors(tensor_); - } - - /*! \brief Get the data type of this GroupedTensorWrapper. */ - DType dtype() const noexcept { - if (tensor_ == nullptr) return DType::kNumTypes; - return static_cast(nvte_grouped_tensor_type(tensor_)); - } - - /*! \brief Get a scaling mode of the grouped tensor. */ - NVTEScalingMode scaling_mode() const noexcept { - if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; - return nvte_grouped_tensor_scaling_mode(tensor_); - } - - /*! \brief Get the logical shape of this GroupedTensorWrapper. */ - const NVTEShape logical_shape() const noexcept { - if (tensor_ == nullptr) { - return emptyShape; - } - return nvte_get_grouped_tensor_logical_shape(tensor_); - } - - static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = { - {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - - private: - NVTEShape convertShape(const NVTEShape &s) { return s; } - - NVTEShape convertShape(const std::vector &s) { - return nvte_make_shape(s.data(), s.size()); - } - - /*! \brief Wrapped NVTEGroupedTensor. */ - NVTEGroupedTensor tensor_ = nullptr; - }; - + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; ======= diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 12be47b638..18577b0eb4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -87,6 +87,7 @@ class Recipe: """ Base recipe class. """ + @classmethod def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index c7bbe4d974..be7521ccd4 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1359,116 +1359,116 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) return t.logical_shape; } -void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, const void *buf, - size_t size_in_bytes) { -// Check attribute and buffer -NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), -")"); -NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); -auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); - -// Read from buffer -switch (param) { -case kNVTEGroupedRowwiseData: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.data = *basic_tensor; -break; -} -case kNVTEGroupedColumnwiseData: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.columnwise_data = *basic_tensor; -break; -} -case kNVTEGroupedScale: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.scale = *basic_tensor; -break; -} -case kNVTEGroupedAmax: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.amax = *basic_tensor; -break; -} -case kNVTEGroupedRowwiseScaleInv: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.scale_inv = *basic_tensor; -break; -} -case kNVTEGroupedColumnwiseScaleInv: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.columnwise_scale_inv = *basic_tensor; -break; -} -case kNVTEGroupedColumnwiseAmax: { -const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -t.columnwise_amax = *basic_tensor; -break; -} -case kNVTEGroupedWithGEMMSwizzledScales: -t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); -break; -default: -NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); -} +void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); + NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); + auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + + // Read from buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.data = *basic_tensor; + break; + } + case kNVTEGroupedColumnwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_data = *basic_tensor; + break; + } + case kNVTEGroupedScale: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale = *basic_tensor; + break; + } + case kNVTEGroupedAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.amax = *basic_tensor; + break; + } + case kNVTEGroupedRowwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale_inv = *basic_tensor; + break; + } + case kNVTEGroupedColumnwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_scale_inv = *basic_tensor; + break; + } + case kNVTEGroupedColumnwiseAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_amax = *basic_tensor; + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); + break; + default: + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); + } } -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, void *buf, - size_t size_in_bytes, size_t *size_written) { -using namespace transformer_engine; +void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + void *buf, size_t size_in_bytes, size_t *size_written) { + using namespace transformer_engine; -// Check param -NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", static_cast(param), -")"); + // Check param + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); -// Return immediately if buffer is not provided -if (buf == nullptr) { -return; -} + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } -// Get C++ tensor -const GroupedTensor *t = convertNVTEGroupedTensor(tensor); + // Get C++ tensor + const GroupedTensor *t = convertNVTEGroupedTensor(tensor); -// Write to buffer -switch (param) { -case kNVTEGroupedRowwiseData: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->data); -break; -} -case kNVTEGroupedColumnwiseData: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->columnwise_data); -break; -} -case kNVTEGroupedScale: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->scale); -break; -} -case kNVTEGroupedAmax: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->amax); -break; -} -case kNVTEGroupedRowwiseScaleInv: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->scale_inv); -break; -} -case kNVTEGroupedColumnwiseScaleInv: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->columnwise_scale_inv); -break; -} -case kNVTEGroupedColumnwiseAmax: { -NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); -*basic_tensor = static_cast(t->columnwise_amax); -break; -} -case kNVTEGroupedWithGEMMSwizzledScales: -*reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); -break; -default: -NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); -} + // Write to buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->data); + break; + } + case kNVTEGroupedColumnwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_data); + break; + } + case kNVTEGroupedScale: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale); + break; + } + case kNVTEGroupedAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->amax); + break; + } + case kNVTEGroupedRowwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale_inv); + break; + } + case kNVTEGroupedColumnwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_scale_inv); + break; + } + case kNVTEGroupedColumnwiseAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_amax); + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); + break; + default: + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); + } } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 08b9dca6d7..e3aacbf2e6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -175,15 +175,26 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - assert qkv_layout == "sbhd_sbhd_sbhd", "sbhd_sbhd_sbhd is assumed to be the shape always at this point in UnfusedDotProductAttention." + assert qkv_layout == "sbhd_sbhd_sbhd", ( + "sbhd_sbhd_sbhd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." + ) q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( - qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype, des_nominal_dtype=query_layer.dtype + qkv_layout, + q_fp8, + k_fp8, + v_fp8, + src_nominal_dtype=query_layer.dtype, + des_nominal_dtype=query_layer.dtype, ) if isinstance(quantizer, MXFP8Quantizer): - assert qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + assert qkv_layout == "bhsd_bhsd_bhsd", ( + "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." + ) # permute back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: @@ -217,7 +228,10 @@ def backward(ctx, grad1, grad2, grad3): ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) if isinstance(ctx.quantizer, MXFP8Quantizer): - assert ctx.qkv_layout == "bhsd_bhsd_bhsd", "bhsd_bhsd_bhsd is assumed to be the shape always at this point in UnfusedDotProductAttention." + assert ctx.qkv_layout == "bhsd_bhsd_bhsd", ( + "bhsd_bhsd_bhsd is assumed to be the shape always at this point in" + " UnfusedDotProductAttention." + ) # permute back to sbhd_sbhd_sbhd tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: @@ -469,7 +483,14 @@ def forward( fp8_dtype=dP_quantizer.dtype, device="cuda" ) # disable swizzle for MXFP8Quantizer - for q in [QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer]: + for q in [ + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ]: if isinstance(q, MXFP8Quantizer): q.optimize_for_gemm = False q.internal = False @@ -477,11 +498,21 @@ def forward( # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", "sbhd_sbhd_sbhd" + query_layer, + key_layer, + value_layer, + QKV_quantizer, + "QKV_quantizer", + "sbhd_sbhd_sbhd", ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", "sbhd_sbhd_sbhd" + query_layer, + key_layer, + value_layer, + dQKV_quantizer, + "dQKV_quantizer", + "sbhd_sbhd_sbhd", ) # [sq, b, np, hn] -> [sq, b * np, hn] @@ -1250,7 +1281,9 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) # print quantizers print_quantizers( @@ -1335,7 +1368,9 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance(QKV_quantizer, MXFP8Quantizer): + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance( + QKV_quantizer, MXFP8Quantizer + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 60b3f7fe71..de5563b7a6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -60,6 +60,7 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + def get_bsh_dims(tensor_format): """Get batch dimension and sequence dimension from tensor format""" if tensor_format in ["bshd", "sbhd", "bhsd"]: @@ -71,6 +72,7 @@ def get_bsh_dims(tensor_format): head_dim = tensor_format.index("h") return batch_dim, seq_dim, head_dim + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -468,7 +470,12 @@ def flash_attn_a2a_communicate( # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] # or [b, np, s, hn] -> [b, cp, np//cp, s, hn] # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:head_dim], cp_size, x.shape[head_dim] // cp_size, *x.shape[head_dim + 1:]) + x = x.view( + *x.shape[:head_dim], + cp_size, + x.shape[head_dim] // cp_size, + *x.shape[head_dim + 1 :], + ) # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] # or [b, cp, np//cp, s, hn] -> [cp, b, np//cp, s, hn] @@ -505,7 +512,7 @@ def flash_attn_a2a_communicate( # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] # or [cp, 2, b, np//cp, s//2, hn] -> [b, cp, np//cp, 2, s//2, hn] # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, head_dim+1).movedim(0, seq_dim+1).contiguous() + x = x.movedim(0, head_dim + 1).movedim(0, seq_dim + 1).contiguous() # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] # or [b, cp, np//cp, 2, s//2, hn] -> [b*np, s, hn] @@ -897,7 +904,9 @@ def cp_p2p_fwd_fused_attn( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -934,7 +943,7 @@ def cp_p2p_fwd_fused_attn( if return_max_logit: return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit - return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None #, new_qkv_layout + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None # , new_qkv_layout def cp_p2p_fwd_flash_attn( @@ -1165,7 +1174,9 @@ def cp_p2p_bwd_fused_attn( ) ] else: - q_part, k_part, v_part, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step) + q_part, k_part, v_part, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step + ) if not fp8_recipe.mxfp8(): if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) @@ -1419,7 +1430,10 @@ def forward( q_fp8, k_fp8, v_fp8 = (None, None, None) # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}") + print( + f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}," + f" is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}" + ) if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v @@ -1450,7 +1464,9 @@ def forward( # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -1901,7 +1917,10 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}") + print( + f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}," + f" out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}" + ) if enable_mla: out = out.view(o_shape) else: @@ -1952,7 +1971,10 @@ def forward( ctx.batch_size = out.shape[1] print(f"========= {torch.cuda.current_device()}: out.shape: {out.shape} {out.dtype}") out_part = out.to(fwd_nominal_dtype) - print(f"========= {torch.cuda.current_device()}: out_part.shape: {out_part.shape} {out_part.dtype}") + print( + f"========= {torch.cuda.current_device()}: out_part.shape:" + f" {out_part.shape} {out_part.dtype}" + ) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -1997,7 +2019,11 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) and not fp8_recipe.mxfp8()) + or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() + ) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -2052,8 +2078,14 @@ def forward( kv_f16 = kv f16_tensors = (q_f16, kv_f16, out_f16) if torch.cuda.current_device() == 0: - print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") - print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") + print( + "fp8_tensors:" + f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" + ) + print( + "f16_tensors:" + f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" + ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -2130,7 +2162,12 @@ def backward(ctx, dout, *_args): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8(): + if ( + ctx.fp8 + and ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2308,7 +2345,9 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + dP_quantizer_per_step[i] = ( + ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + ) dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() if not ctx.fp8_recipe.mxfp8(): dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) @@ -2328,7 +2367,10 @@ def backward(ctx, dout, *_args): # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print(f"========= {torch.cuda.current_device()}: before a2a: out.shape: {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}") + print( + f"========= {torch.cuda.current_device()}: before a2a: out.shape:" + f" {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}" + ) if not ctx.use_fused_attention: # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) @@ -2344,7 +2386,10 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) - print(f"========= {torch.cuda.current_device()}: after a2a: dout.shape: {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}") + print( + f"========= {torch.cuda.current_device()}: after a2a: dout.shape:" + f" {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}" + ) if ctx.enable_mla: out = out.view(*ctx.o_shape) @@ -2449,7 +2494,8 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8() + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or ctx.fp8_recipe.mxfp8() else out_fp8 ), dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, @@ -3010,7 +3056,9 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs["s_quantizer"] = S_quantizer @@ -3090,11 +3138,19 @@ def forward( new_qkv_layout = qkv_layout if fp8: if not fp8_recipe.mxfp8(): - q_part = Float8Tensor.make_like(q_fp8, data=q_, dtype=fwd_nominal_dtype) - k_part = Float8Tensor.make_like(k_fp8, data=k_, dtype=fwd_nominal_dtype) - v_part = Float8Tensor.make_like(v_fp8, data=v_, dtype=fwd_nominal_dtype) + q_part = Float8Tensor.make_like( + q_fp8, data=q_, dtype=fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_, dtype=fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_, dtype=fwd_nominal_dtype + ) else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ( out_per_step[i], aux_ctx_tensors, @@ -3217,8 +3273,14 @@ def forward( else: f16_tensors = (q, k, v, out) if torch.cuda.current_device() == 0: - print(f"fp8_tensors: {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}") - print(f"f16_tensors: {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}") + print( + "fp8_tensors:" + f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" + ) + print( + "f16_tensors:" + f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" + ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3282,8 +3344,14 @@ def backward(ctx, dout, *_args): softmax_lse_per_step = [None, None] rng_states = [None, None] ( - q_fp8, k_fp8, v_fp8, out_fp8, - q, k, v, out, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + q, + k, + v, + out, cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv_per_step[0], @@ -3320,8 +3388,16 @@ def backward(ctx, dout, *_args): if torch.cuda.current_device() == 0: print(f"ctx.q_shape: {ctx.q_shape} {ctx.k_shape} {ctx.v_shape}") dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) - dk = torch.zeros((ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=k.device) - dv = torch.zeros((ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), dtype=ctx.fwd_nominal_dtype, device=v.device) + dk = torch.zeros( + (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=k.device, + ) + dv = torch.zeros( + (ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=v.device, + ) if torch.cuda.current_device() == 0: print(f"dq: {dq.shape} {dq.dtype} {dq.device}") print(f"dk: {dk.shape} {dk.dtype} {dk.device}") @@ -3400,7 +3476,11 @@ def backward(ctx, dout, *_args): out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], softmax_lse_per_step[i], rng_states[i]] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] q_part, k_part, v_part, out_part, dout_part = q_, k_, v_, out_, dout_ fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} @@ -3421,17 +3501,34 @@ def backward(ctx, dout, *_args): v_part = Float8Tensor.make_like( v_fp8, data=v_, dtype=ctx.fwd_nominal_dtype ) - if not (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + if not ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): out_part = ctx.O_quantizer(out_part) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype) + dout_part = Float8Tensor.make_like( + dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype + ) else: - q_part, k_part, v_part, new_qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer) - print(f"aux_ctx_tensors: {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}") - dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer + ) + print( + "aux_ctx_tensors:" + f" {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}" + ) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( + d_out_format, dout_part + ) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) - print(f"q_part type: {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}") - print(f"q_part shape: {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}") + print( + "q_part type:" + f" {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}" + ) + print( + "q_part shape:" + f" {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}" + ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3463,7 +3560,11 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - x.dequantize(dtype=ctx.fwd_nominal_dtype) if isinstance(x, QuantizedTensorStorage) else x + ( + x.dequantize(dtype=ctx.fwd_nominal_dtype) + if isinstance(x, QuantizedTensorStorage) + else x + ) for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] ] else: @@ -3506,7 +3607,10 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if torch.cuda.current_device() == 0: - print(f"dq.shape: {dq.shape} dq_per_step[i - 1].shape: {dq_per_step[i - 1].shape}") + print( + f"dq.shape: {dq.shape} dq_per_step[i - 1].shape:" + f" {dq_per_step[i - 1].shape}" + ) if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -3717,13 +3821,15 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # else: # q, k, v = [q_fp8, k_fp8, v_fp8] - # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) - # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + # qkv_format, _, _ = dpa_utils.get_qkv_format(qkv_layout) + # batch_dim_qkv, seq_dim_qkv, _ = get_bsh_dims(qkv_format) fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["o_quantizer"] = O_quantizer @@ -3763,7 +3869,9 @@ def forward( for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] if fp8 and fp8_recipe.mxfp8(): - q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize(qkv_layout, q_part, k_part, v_part, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, @@ -3995,7 +4103,7 @@ def backward(ctx, dout, *_args): dout_fp8 = dout if not ctx.fp8_recipe.mxfp8(): # dqkv_te_dtype = dout._fp8_dtype - dout = dout._data + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -4075,7 +4183,9 @@ def backward(ctx, dout, *_args): q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or ctx.fp8_recipe.mxfp8(): + if ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or ctx.fp8_recipe.mxfp8(): out_part = out if not ctx.fp8_recipe.mxfp8(): dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) @@ -4174,7 +4284,9 @@ def backward(ctx, dout, *_args): ) if ctx.fp8: - if (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()) and ctx.is_input_fp8: + if ( + ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() + ) and ctx.is_input_fp8: dq, dk, dv = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 5b32f35be0..9a8d38547e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2209,21 +2209,21 @@ def print_quantizers( f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" ) else: - print( - f"{label} >> {names[i]:14s}: {type_str}" - ) + print(f"{label} >> {names[i]:14s}: {type_str}") + def permute_to_grouped_tensor(src_format, tensor): """Permute tensor to bhsd or htd format for grouped quantization in MXFP8BlockScaling. src_format ={bshd, sbhd, thd}""" if src_format in ["bhsd", "htd"]: return tensor, src_format tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor - dim_s_or_t = src_format.find("s") if 's' in src_format else src_format.find("t") + dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] perm = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] tensor = tensor.permute(*perm).contiguous() return tensor, "bhsd" if src_format != "thd" else "htd" + def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" # 1: qkv packed, 2: kv packed, 3: qkv separate @@ -2244,7 +2244,10 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] - print(f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}, s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}") + print( + f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}," + f" s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}" + ) assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 @@ -2254,7 +2257,9 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): # consider bhsd for now if d_qk == d_v: - grouped_tensor = GroupedTensor.create_and_quantize(tensors=[q, k, v], quantizer=qkv_quantizer) + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=[q, k, v], quantizer=qkv_quantizer + ) q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors else: q_fp8 = qkv_quantizer(q) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index ebed38fc84..f757bbdfee 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -306,7 +306,7 @@ class MXFP8Quantizer : public Quantizer { * amax to be initialized to zero. */ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype, std::optional data = std::nullopt); + const std::vector& shape, DType dtype, std::optional data = std::nullopt); std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3d5ad2e598..95c985062a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -98,11 +98,13 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, + const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index fd193b0258..bd5a5a065f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -153,8 +153,9 @@ std::vector fused_attn_fwd( auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; - size_t b=0, h=0, s=0, d=0, t=0; - nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); + size_t b = 0, h = 0, s = 0, d = 0, t = 0; + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), o_shape_tmp, o_format, o_shape, &b, &h, &s, + &d, &t); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -254,9 +255,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -314,9 +315,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -328,11 +329,13 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, //const DType dqkv_type, + const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, //const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -365,11 +368,14 @@ std::vector fused_attn_bwd( // auto h_kv = k_shape[k_shape.size() - 2]; // auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - size_t b=0, h_q=0, h_kv=0, s_q=0, s_kv=0, d_qk=0, d_v=0, t_q=0, t_kv=0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; std::vector dQ_shape(4), dK_shape(4), dV_shape(4); - nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); - nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); - nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + nvte_convert_qkv_format(nvte_get_q_format(qkv_layout), q_shape, nvte_get_q_format(dqkv_layout), + dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), k_shape, nvte_get_kv_format(dqkv_layout), + dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(nvte_get_kv_format(qkv_layout), v_shape, nvte_get_kv_format(dqkv_layout), + dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); at::Tensor dQ, dK, dV, dQKV, dKV; DType dqkv_type = fake_dtype_te; @@ -380,7 +386,8 @@ std::vector fused_attn_bwd( if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); @@ -447,7 +454,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); @@ -566,9 +573,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -583,9 +590,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 92239aafc0..b44640d006 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1230,10 +1230,10 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } -std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax(const std::vector& shape, - DType dtype, - std::optional data) { - at::Tensor amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data) { + at::Tensor amax_tensor = + at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) : NoneQuantizer(py::none()).create_tensor(shape, dtype); TensorWrapper out_cpp = std::move(out.first); diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 255a9ecfd3..466429cf3f 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -975,4 +975,4 @@ def quantize( self.quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) - return self.quantized_tensors \ No newline at end of file + return self.quantized_tensors From 81c18fa8fc854e5a977581f815e94dc4097d955e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:05:21 -0800 Subject: [PATCH 43/59] fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_grouped_tensor.py | 2 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 12 -- .../common/fused_attn/fused_attn_fp8.cu | 1 - .../transformer_engine/transformer_engine.h | 29 +---- .../common/transformer_engine.cpp | 114 ------------------ .../pytorch/csrc/extensions/attention.cpp | 6 +- .../pytorch/tensor/storage/grouped_tensor.py | 6 +- 7 files changed, 9 insertions(+), 161 deletions(-) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 31d84933de..de00d0cf35 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -393,7 +393,7 @@ def test_quantize_grouped_mxfp8(self) -> None: device="cuda", ) # Quantize using grouped API (handle both 2-arg and 3-arg bindings) - _ = tex.quantize_grouped(grouped_input, grouped_output) + _ = tex.group_quantize(grouped_input, grouped_output) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index c9816494bb..f454209409 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -787,7 +787,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; checkCuDriverContext(stream); - // CheckNoopTensor(*noop, "cast_noop"); const bool use_rowwise_scaling = output->has_data(); const bool use_colwise_scaling = output->has_columnwise_data(); @@ -800,13 +799,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } else if (!use_rowwise_scaling) { scaling_type = ScalingType::COLWISE; } - // if (use_rowwise_scaling && (!use_colwise_scaling)) { - // scaling_type = ScalingType::ROWWISE; - // } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - // scaling_type = ScalingType::COLWISE; - // } else if (use_rowwise_scaling && use_colwise_scaling) { - // scaling_type = ScalingType::BIDIMENSIONAL; - // } ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; if (output->all_same_shape()) { @@ -886,10 +878,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 237f3bd66e..48c3975264 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2084,7 +2084,6 @@ void fused_attn_fp8_bwd_impl_v1( bool is_dropout = (dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; - const auto cudnn_runtime_version = cudnnGetVersion(); auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index a6cb036a35..3dacc596c8 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,13 +449,15 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ +<<<<<<< HEAD <<<<<<< HEAD kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ ======= +======= +>>>>>>> 341cc3df (fix merge) kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ ->>>>>>> main kNVTENumGroupedTensorParams }; @@ -511,27 +513,6 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor tensor, NVTEGroupedTensorPa void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, void *buf, size_t size_in_bytes, size_t *size_written); -/*! \brief Set a parameter of the grouped tensor. - * - * \param[in/out] tensor Grouped tensor. - * \param[in] param_name The parameter to be set. - * \param[in] param The value to be set (NVTEBasicTensor). - */ -void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param_name, - const void *buf, size_t size_in_bytes); - -/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Get a value of the parameter of the grouped tensor. - * - * \param[in] tensor Grouped tensor. - * \param[in] param_name The parameter to be queried. - * - * \return NVTEBasicTensor containing the parameter data. - */ -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, - NVTEGroupedTensorParam param_name, void *buf, - size_t size_in_bytes, size_t *size_written); - /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. * @@ -994,6 +975,7 @@ class TensorWrapper { * \brief C++ wrapper for the NVTEGroupedTensor class. */ +<<<<<<< HEAD <<<<<<< HEAD class GroupedTensorWrapper { public: @@ -1212,6 +1194,8 @@ class GroupedTensorWrapper { /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; ======= +======= +>>>>>>> 341cc3df (fix merge) class GroupedTensorWrapper { public: /*! \brief Constructs new GroupedTensorWrapper. @@ -1437,7 +1421,6 @@ enum class Float8BlockScaleTensorFormat { COMPACT = 1, INVALID }; ->>>>>>> main /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index be7521ccd4..cd02074fbd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1358,117 +1358,3 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); return t.logical_shape; } - -void nvte_set_grouped_tensor_param_v2(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, - const void *buf, size_t size_in_bytes) { - // Check attribute and buffer - NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", - static_cast(param), ")"); - NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); - auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); - - // Read from buffer - switch (param) { - case kNVTEGroupedRowwiseData: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.data = *basic_tensor; - break; - } - case kNVTEGroupedColumnwiseData: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.columnwise_data = *basic_tensor; - break; - } - case kNVTEGroupedScale: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.scale = *basic_tensor; - break; - } - case kNVTEGroupedAmax: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.amax = *basic_tensor; - break; - } - case kNVTEGroupedRowwiseScaleInv: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.scale_inv = *basic_tensor; - break; - } - case kNVTEGroupedColumnwiseScaleInv: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.columnwise_scale_inv = *basic_tensor; - break; - } - case kNVTEGroupedColumnwiseAmax: { - const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - t.columnwise_amax = *basic_tensor; - break; - } - case kNVTEGroupedWithGEMMSwizzledScales: - t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); - break; - default: - NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); - } -} - -void nvte_get_grouped_tensor_param_v2(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, - void *buf, size_t size_in_bytes, size_t *size_written) { - using namespace transformer_engine; - - // Check param - NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", - static_cast(param), ")"); - - // Return immediately if buffer is not provided - if (buf == nullptr) { - return; - } - - // Get C++ tensor - const GroupedTensor *t = convertNVTEGroupedTensor(tensor); - - // Write to buffer - switch (param) { - case kNVTEGroupedRowwiseData: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->data); - break; - } - case kNVTEGroupedColumnwiseData: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->columnwise_data); - break; - } - case kNVTEGroupedScale: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->scale); - break; - } - case kNVTEGroupedAmax: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->amax); - break; - } - case kNVTEGroupedRowwiseScaleInv: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->scale_inv); - break; - } - case kNVTEGroupedColumnwiseScaleInv: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->columnwise_scale_inv); - break; - } - case kNVTEGroupedColumnwiseAmax: { - NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); - *basic_tensor = static_cast(t->columnwise_amax); - break; - } - case kNVTEGroupedWithGEMMSwizzledScales: - *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); - break; - default: - NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); - } -} diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bd5a5a065f..192a774ca0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -166,8 +166,6 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - // auto h = q_shape[q_shape.size() - 2]; - // auto d = q_shape[q_shape.size() - 1]; if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -319,6 +317,7 @@ std::vector fused_attn_fwd( attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); + // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -364,9 +363,6 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - // auto h_q = q_shape[q_shape.size() - 2]; - // auto h_kv = k_shape[k_shape.size() - 2]; - // auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; std::vector dQ_shape(4), dK_shape(4), dV_shape(4); diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index 466429cf3f..ef91e58e7c 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -53,11 +53,7 @@ class GroupedTensor: def __init__( self, num_tensors: int, -<<<<<<< HEAD shapes: List[Tuple[int, int]], -======= - shape: Optional[List[Tuple[int, int]]] = None, ->>>>>>> main quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, @@ -959,7 +955,7 @@ def create_and_quantize( dtype=dtype, ) - _ = tex.quantize_grouped(grouped_input, grouped_output) + _ = tex.group_quantize(grouped_input, grouped_output) grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() return grouped_output From 1f14f2fa6ba3493598d0386df5af7e01f74daa42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Mar 2026 23:08:21 +0000 Subject: [PATCH 44/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/storage/grouped_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index ef91e58e7c..c002067e11 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -970,5 +970,7 @@ def quantize( """ self.quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): - self.quantizer.update_quantized(tensors[i], self.quantized_tensors[i], noop_flag=noop_flag) + self.quantizer.update_quantized( + tensors[i], self.quantized_tensors[i], noop_flag=noop_flag + ) return self.quantized_tensors From ccebe771058024ac51cb335633136f6df779fb57 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:17:17 -0800 Subject: [PATCH 45/59] fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../transformer_engine/transformer_engine.h | 229 ------------------ 1 file changed, 229 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 3dacc596c8..635b9fdcce 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,13 +449,6 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ -<<<<<<< HEAD -<<<<<<< HEAD - kNVTEGroupedWithGEMMSwizzledScales = - 10, /*!< Whether scaling factors are in format expected by GEMM */ -======= -======= ->>>>>>> 341cc3df (fix merge) kNVTEGroupedWithGEMMSwizzledScales = 10, /*!< Whether scaling factors are in format expected by GEMM */ kNVTENumGroupedTensorParams @@ -974,228 +967,6 @@ class TensorWrapper { /*! \struct GroupedTensorWrapper * \brief C++ wrapper for the NVTEGroupedTensor class. */ - -<<<<<<< HEAD -<<<<<<< HEAD -class GroupedTensorWrapper { - public: - /*! \brief Constructs new GroupedTensorWrapper. - * - * Create a new TE grouped tensor with a given logical shape. - * TE grouped tensors are just wrappers on top of raw data and do not - * own memory. - * - * \param[in] num_tensors Number of tensors in the group (must be > 0). - * \param[in] logical_shape Logical 2D shape of the grouped data. - * \param[in] scaling_mode Tensor data format. - */ - GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} - - /*! \brief Constructs new GroupedTensorWrapper. - * - * Create a new TE grouped tensor with a given logical shape. - * - * \param[in] num_tensors Number of tensors in the group (must be > 0). - * \param[in] logical_shape Logical 2D shape of the grouped data. - * \param[in] scaling_mode Tensor data format. - */ - GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : GroupedTensorWrapper(num_tensors, - nvte_make_shape(logical_shape.data(), logical_shape.size()), - scaling_mode) {} - - /*! \brief GroupedTensorWrapper destructor. */ - ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } - - GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; - GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; - - /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ - GroupedTensorWrapper(GroupedTensorWrapper &&other) { - tensor_ = other.tensor_; - other.tensor_ = nullptr; - } - - /*! \brief Assign the data from existing GroupedTensorWrapper. */ - GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { - if (this == &other) return *this; - nvte_destroy_grouped_tensor(tensor_); - tensor_ = other.tensor_; - other.tensor_ = nullptr; - return *this; - } - - // Parameter setters - template - GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, - const ShapeType &shape) noexcept { - NVTEShape nvte_shape = this->convertShape(shape); - NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; - nvte_set_grouped_tensor_param(&tensor_, param, &data); - return *this; - } - - template - GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedScale, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); - } - - void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { - const auto val = static_cast(with_gemm_swizzled_scales); - nvte_set_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, - sizeof(val)); - } - - // Parameter getters - NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { - return nvte_get_grouped_tensor_param(tensor_, param); - } - - NVTEBasicTensor get_rowwise_data() const noexcept { - return get_parameter(kNVTEGroupedRowwiseData); - } - - NVTEBasicTensor get_columnwise_data() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseData); - } - - NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } - - NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } - - NVTEBasicTensor get_rowwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedRowwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_amax() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseAmax); - } - - NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } - - NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } - - NVTEBasicTensor get_tensor_offsets() const noexcept { - return get_parameter(kNVTEGroupedTensorOffsets); - } - - bool get_with_gemm_swizzled_scales() const { - uint8_t val = 0; - nvte_get_grouped_tensor_param_v2(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), - nullptr); - return static_cast(val); - } - - /*! \brief Get an underlying NVTEGroupedTensor. - * - * \return NVTEGroupedTensor held by this GroupedTensorWrapper. - */ - NVTEGroupedTensor data() const noexcept { return tensor_; } - - /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ - size_t num_tensors() const noexcept { - if (tensor_ == nullptr) return 0; - return nvte_grouped_tensor_num_tensors(tensor_); - } - - /*! \brief Get the data type of this GroupedTensorWrapper. */ - DType dtype() const noexcept { - if (tensor_ == nullptr) return DType::kNumTypes; - return static_cast(nvte_grouped_tensor_type(tensor_)); - } - - /*! \brief Get a scaling mode of the grouped tensor. */ - NVTEScalingMode scaling_mode() const noexcept { - if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; - return nvte_grouped_tensor_scaling_mode(tensor_); - } - - /*! \brief Get the logical shape of this GroupedTensorWrapper. */ - const NVTEShape logical_shape() const noexcept { - if (tensor_ == nullptr) { - return emptyShape; - } - return nvte_get_grouped_tensor_logical_shape(tensor_); - } - - static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = { - {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - - private: - NVTEShape convertShape(const NVTEShape &s) { return s; } - - NVTEShape convertShape(const std::vector &s) { - return nvte_make_shape(s.data(), s.size()); - } - - /*! \brief Wrapped NVTEGroupedTensor. */ - NVTEGroupedTensor tensor_ = nullptr; -}; - -/*! \warning Deprecated */ -enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; -======= -======= ->>>>>>> 341cc3df (fix merge) class GroupedTensorWrapper { public: /*! \brief Constructs new GroupedTensorWrapper. From c52c5f41aafba3dd642ce5449f79379406fad9e2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:33:26 -0800 Subject: [PATCH 46/59] revert to main grouped tensor impl Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_grouped_tensor.py | 117 +++++++++--------- .../pytorch/tensor/storage/grouped_tensor.py | 80 +++++------- 2 files changed, 91 insertions(+), 106 deletions(-) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index de00d0cf35..ad08c0474d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -121,11 +121,11 @@ def setup_class(cls) -> None: def test_basic_construction_all_same_shape(self) -> None: """Test GroupedTensor construction with all tensors having same shape""" num_tensors = 4 - shapes = [(256, 512) for _ in range(num_tensors)] + shape = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -143,11 +143,11 @@ def test_basic_construction_all_same_shape(self) -> None: def test_basic_construction_varying_first_dim(self) -> None: """Test GroupedTensor construction with varying first dimension""" num_tensors = 3 - shapes = [(128, 512), (256, 512), (384, 512)] + shape = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -157,20 +157,20 @@ def test_basic_construction_varying_first_dim(self) -> None: assert not grouped_tensor.all_same_shape() assert not grouped_tensor.all_same_first_dim() assert grouped_tensor.all_same_last_dim() - assert grouped_tensor.get_common_last_dim() == shapes[0][1] + assert grouped_tensor.get_common_last_dim() == shape[0][1] assert grouped_tensor.logical_shape == ( - sum(v for v, _ in shapes), - shapes[0][1], + sum(v for v, _ in shape), + shape[0][1], ) # sum of first dims def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 - shapes = [(256, 512) for _ in range(num_tensors)] + shape = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -186,7 +186,7 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: # Verify each tensor has correct shape and shares storage for i, tensor in enumerate(tensors): - assert tensor.shape == shapes[i] + assert tensor.shape == shape[i] assert isinstance(tensor, torch.Tensor) assert not hasattr(tensor, "_data") # Not a quantized tensor @@ -195,19 +195,19 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: assert tensor.data_ptr() >= original_data_ptr # Calculate expected offset - expected_offset = i * (shapes[i][0] * shapes[i][1]) * tensor.element_size() + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset @pytest.mark.parametrize("quantization", _quantization_params) def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 - shapes = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=quantizer, device="cuda", ) @@ -225,18 +225,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None rowwise_data = _get_rowwise_data_tensor(tensor, quantization) assert rowwise_data is not None assert rowwise_data.data_ptr() >= original_data_ptr - numel = shapes[i][0] * shapes[i][1] + numel = shape[i][0] * shape[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset def test_split_varying_shapes(self) -> None: """Test split_into_quantized_tensors with varying shapes""" num_tensors = 3 - shapes = [(128, 512), (256, 512), (384, 512)] + shape = [(128, 512), (256, 512), (384, 512)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -250,21 +250,21 @@ def test_split_varying_shapes(self) -> None: # Verify shapes and storage cumulative_offset = 0 for i, tensor in enumerate(tensors): - assert tensor.shape == shapes[i] + assert tensor.shape == shape[i] expected_offset = cumulative_offset * tensor.element_size() assert tensor.data_ptr() == original_data_ptr + expected_offset - cumulative_offset += shapes[i][0] * shapes[i][1] + cumulative_offset += shape[i][0] * shape[i][1] @pytest.mark.parametrize("quantization", _quantization_params) def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 - shapes = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=quantizer, device="cuda", ) @@ -277,7 +277,7 @@ def test_quantize_inplace(self, quantization: str) -> None: ) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -291,7 +291,7 @@ def test_quantize_inplace(self, quantization: str) -> None: # Verify returned tensors point to the same storage for i, qtensor in enumerate(quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shapes[i][0] * shapes[i][1] + numel = shape[i][0] * shape[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -299,12 +299,12 @@ def test_quantize_inplace(self, quantization: str) -> None: def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 - shapes = [(256, 512), (512, 512), (768, 512)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(256, 512), (512, 512), (768, 512)] + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=quantizer, device="cuda", ) @@ -313,7 +313,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() # Create input tensors with varying shapes - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] # Quantize in place quantized_tensors = grouped_tensor.quantize(input_tensors) @@ -323,7 +323,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: # Verify each tensor points to correct location cumulative_numel = 0 - for qtensor, tensor_shape in zip(quantized_tensors, shapes): + for qtensor, tensor_shape in zip(quantized_tensors, shape): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset @@ -333,11 +333,11 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: def test_static_quantize_method(self, quantization: str) -> None: """Test the static quantize method""" num_tensors = 3 - shapes = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shapes) + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shapes] + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] # Use static quantize method grouped_tensor = GroupedTensor.create_and_quantize( @@ -357,43 +357,44 @@ def test_static_quantize_method(self, quantization: str) -> None: original_data_ptr = grouped_tensor.data.data_ptr() for i, qtensor in enumerate(grouped_tensor.quantized_tensors): rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shapes[i][0] * shapes[i][1] + numel = shape[i][0] * shape[i][1] expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_quantize_grouped_mxfp8(self) -> None: + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" # Test wont pass until the grouped quantization PR from Oleg is merged. num_tensors = 2 - shapes = [(512, 1024) for _ in range(num_tensors)] + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a 2D tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantized_tensors = [ + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors + ] + grouped_input = torch.cat(input_tensors, dim=0) - # Create BF16 input tensors and pack into a grouped tensor - input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes] + # Create MXFP8 output grouped tensor (rowwise only for easier validation) quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - quantizer.optimize_for_gemm = True - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=None, + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, device="cuda", - dtype=torch.bfloat16, ) - offset = 0 - for tensor in input_tensors: - numel = tensor.numel() - grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) - offset += numel - - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shapes=shapes, - quantizer=quantizer, - device="cuda", + # Quantize using grouped API + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, ) - # Quantize using grouped API (handle both 2-arg and 3-arg bindings) - _ = tex.group_quantize(grouped_input, grouped_output) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] @@ -456,11 +457,11 @@ def test_group_quantize_cudagraph_capturable(self) -> None: def test_clear(self) -> None: """Test clear method""" num_tensors = 3 - shapes = [(256, 512) for _ in range(num_tensors)] + shape = [(256, 512) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shapes=shapes, + shape=shape, quantizer=None, device="cuda", dtype=torch.float32, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index c002067e11..bf5792ffc9 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -8,8 +8,7 @@ import math import torch -import transformer_engine -import transformer_engine_torch as tex + from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -53,7 +52,7 @@ class GroupedTensor: def __init__( self, num_tensors: int, - shapes: List[Tuple[int, int]], + shape: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, @@ -76,7 +75,7 @@ def __init__( Args: num_tensors: Number of tensors in the group - shapes: 2D shape of each tensor (len num_tensors) + shape: 2D shape of each tensor (len num_tensors) quantizer: Quantizer for the grouped tensor data: Row-wise data buffer (1D flattened) columnwise_data: Column-wise data buffer (1D flattened) @@ -93,7 +92,7 @@ def __init__( """ self.num_tensors = num_tensors self.quantizer = quantizer - self.shapes = shapes + self.shape = shape self.dtype = ( dtype if dtype is not None else torch.float32 ) # Default to float32 if not provided @@ -269,7 +268,7 @@ def __repr__(self) -> str: """String representation of the GroupedTensor.""" return ( f"GroupedTensor(num_tensors={self.num_tensors}, " - f"shapes={self.shapes}, " + f"shape={self.shape}, " f"logical_shape={self.logical_shape}, " f"dtype={self.get_dtype()})" ) @@ -295,7 +294,7 @@ def __str__(self) -> str: @staticmethod def make_grouped_tensor_with_shapes( num_tensors: int, - shapes: List[Tuple[int, int]], + shape: List[Tuple[int, int]], quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -305,8 +304,8 @@ def make_grouped_tensor_with_shapes( Args: num_tensors: Number of tensors - shapes: 2D shape of each tensor (len num_tensors) - quantizer: Quantizer for the grouped tensor + shape: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for each tensor device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -315,16 +314,16 @@ def make_grouped_tensor_with_shapes( """ # First dim - first_dim_list = [s[0] for s in shapes] + first_dim_list = [s[0] for s in shape] uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) logical_first_dim = sum(first_dim_list) if uniform_first_dim: first_dims = None else: - first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) # Last dim - last_dim_list = [s[1] for s in shapes] + last_dim_list = [s[1] for s in shape] logical_last_dim = last_dim_list[0] assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" @@ -359,7 +358,7 @@ def make_grouped_tensor( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) logical_first_dim: Logical first dimension logical_last_dim: Logical last dimension - quantizer: Quantizer for the grouped tensor + quantizer: Quantizer for each tensor Used to figure out the recipe and what to allocate. device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -388,7 +387,7 @@ def make_grouped_tensor( # Calculate tensor offsets (cumulative element offsets) tensor_offsets = None offsets = None - shapes = [] + shape = [] if not all_same_first: # Need explicit offsets for non-uniform shapes # Offsets are based on number of elements and not pointers. @@ -404,14 +403,14 @@ def make_grouped_tensor( offsets = tensor_offsets.tolist() first_dims_list = first_dims.tolist() for i in range(num_tensors): - shapes.append((first_dims_list[i], logical_last_dim)) + shape.append((first_dims_list[i], logical_last_dim)) else: offsets = [ i * logical_first_dim * logical_last_dim // num_tensors for i in range(num_tensors + 1) ] for i in range(num_tensors): - shapes.append((logical_first_dim // num_tensors, logical_last_dim)) + shape.append((logical_first_dim // num_tensors, logical_last_dim)) # Calculate logical shape based logical_shape = (logical_first_dim, logical_last_dim) @@ -450,7 +449,7 @@ def make_grouped_tensor( # For grouped tensors, we need to calculate scale_inv size for all tensors total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) scale_elements = math.prod(scale_inv_shape) total_scale_elements += scale_elements @@ -463,7 +462,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements @@ -499,7 +498,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) @@ -515,7 +514,7 @@ def make_grouped_tensor( # Columnwise scale inverse buffer total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -532,7 +531,7 @@ def make_grouped_tensor( # For simplicity, calculate total scale elements needed total_scale_elements = 0 scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) @@ -544,7 +543,7 @@ def make_grouped_tensor( # Columnwise scale inverse total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] - for i, s in enumerate(shapes): + for i, s in enumerate(shape): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -577,7 +576,7 @@ def make_grouped_tensor( grouped_tensor = GroupedTensor( num_tensors=num_tensors, - shapes=shapes, + shape=shape, dtype=dtype, quantizer=quantizer, data=data, @@ -646,7 +645,7 @@ def split_into_quantized_tensors( if no_quantization: for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shapes[i] + tensor_shape = self.shape[i] # Get tensor data slice if self.offsets is not None: @@ -700,7 +699,7 @@ def split_into_quantized_tensors( for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shapes[i] + tensor_shape = self.shape[i] numel = tensor_shape[0] * tensor_shape[1] # Get data offsets @@ -933,32 +932,18 @@ def create_and_quantize( Quantize given tensors into quantized tensors with underlying storage allocated in a GroupedTensor. """ - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=len(tensors), - shapes=[t.shape for t in tensors], - quantizer=None, - device=device, - dtype=tensors[0].dtype, - ) - - offset = 0 - for tensor in tensors: - numel = tensor.numel() - grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) - offset += numel - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=len(tensors), - shapes=[t.shape for t in tensors], + shape=[t.shape for t in tensors], quantizer=quantizer, device=device, dtype=dtype, ) - _ = tex.group_quantize(grouped_input, grouped_output) - grouped_output.quantized_tensors = grouped_output.split_into_quantized_tensors() + grouped_tensor.quantize(tensors, noop_flag=noop_flag) - return grouped_output + return grouped_tensor def quantize( self, @@ -968,9 +953,8 @@ def quantize( """ Quantize the GroupedTensor inplace. """ - self.quantized_tensors = self.split_into_quantized_tensors() + + quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): - self.quantizer.update_quantized( - tensors[i], self.quantized_tensors[i], noop_flag=noop_flag - ) - return self.quantized_tensors + self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors From 5b776ec2489a69cb0db49fa8275010a97b5fa019 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:39:19 -0800 Subject: [PATCH 47/59] minor tweaks to return to main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/cast/cast.cu | 1 + transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 1 + .../common/include/transformer_engine/transformer_engine.h | 1 + 3 files changed, 3 insertions(+) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index f4825970cb..57404ae8a5 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,6 +30,7 @@ void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); using namespace transformer_engine; + constexpr bool IS_ACT = false; dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); } diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index f454209409..6447fc4542 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -787,6 +787,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); const bool use_rowwise_scaling = output->has_data(); const bool use_colwise_scaling = output->has_columnwise_data(); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 635b9fdcce..e316f8be8c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -967,6 +967,7 @@ class TensorWrapper { /*! \struct GroupedTensorWrapper * \brief C++ wrapper for the NVTEGroupedTensor class. */ + class GroupedTensorWrapper { public: /*! \brief Constructs new GroupedTensorWrapper. From 4eee2bcceb1ff9b06b11aa2c6b0f67f9ad54bf20 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:13:20 -0800 Subject: [PATCH 48/59] remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 28 ++--- .../common/fused_attn/fused_attn_fp8.cu | 104 +----------------- .../dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 76 ------------- .../attention/dot_product_attention/utils.py | 4 - 5 files changed, 16 insertions(+), 198 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 47abf1ebc6..f39ed547cb 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1806,21 +1806,21 @@ def get_model(dtype, config): # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), "fp8_10": ModelConfig( - 2, 2048, 24, 192, head_dim_v=128, num_gqa_groups=12, window_size=(512, 512) - ), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + 2, 2048, 24, 192, head_dim_v=128, #num_gqa_groups=12, window_size=(512, 512) + ), + # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + # "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + # "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.bfloat16] #[torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] @@ -2054,7 +2054,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - + print(f"type(out_grad): {type(out_grad)} {out_grad.shape}") with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 48c3975264..9796e39ddc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2301,43 +2301,6 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); - printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], - q_t_stride[3]); - printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], - k_t_stride[3]); - printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], - dO_t_stride[3]); - printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, - cudnn_frontend::DataType_t::BFLOAT16); - printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, - cudnn_frontend::DataType_t::FP8_E5M2); - printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("b: %d\n", b); - printf("h: %d\n", h); - printf("hg: %d\n", hg); - printf("s_q: %d\n", s_q); - printf("s_kv: %d\n", s_kv); - printf("d_qk: %d\n", d_qk); - printf("d_v: %d\n", d_v); - printf("is_delayed_scaling: %d\n", is_delayed_scaling); - printf("is_current_scaling: %d\n", is_current_scaling); - printf("is_O_in_F16: %d\n", is_O_in_F16); - printf("is_mxfp8: %d\n", is_mxfp8); - printf("is_causal: %d\n", is_causal); - printf("is_padding: %d\n", is_padding); - printf("is_dropout: %d\n", is_dropout); - printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2360,18 +2323,6 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - printf("s_q_padded: %d\n", padded.s_q_padded); - printf("s_kv_padded: %d\n", padded.s_kv_padded); - printf("s_q_scale: %d\n", padded.s_q_scale); - printf("s_kv_scale: %d\n", padded.s_kv_scale); - printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); - printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); - printf("d_qk_padded: %d\n", padded.d_qk_padded); - printf("d_v_padded: %d\n", padded.d_v_padded); - printf("d_qk_scale: %d\n", padded.d_qk_scale); - printf("d_v_scale: %d\n", padded.d_v_scale); - printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); - printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2393,20 +2344,6 @@ void fused_attn_fp8_bwd_impl_v1( dO_scale_strides.data(), d_out_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); - printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], - q_scale_strides[2], q_scale_strides[3]); - printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], - q_t_scale_strides[2], q_t_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], - k_scale_strides[2], k_scale_strides[3]); - printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], - k_t_scale_strides[2], k_t_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], - v_scale_strides[2], v_scale_strides[3]); - printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], - dO_scale_strides[2], dO_scale_strides[3]); - printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], - dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2562,9 +2499,6 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); - printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); - printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_stride) @@ -2694,7 +2628,7 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[scale_dP] = devPtrScaledP; variant_pack[amax_dP] = devPtrAmaxdP; } - if (is_current_scaling && !is_O_in_F16) { + if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { variant_pack[descale_o] = devPtrDescaleO; } if (is_delayed_scaling) { @@ -2712,42 +2646,6 @@ void fused_attn_fp8_bwd_impl_v1( // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } - int64_t modulo = 16; - printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); - printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); - printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); - printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); - printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); - printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); - printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, - is_aligned_modulo(devPtrDescaleQ, modulo)); - printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, - is_aligned_modulo(devPtrDescaleK, modulo)); - printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, - is_aligned_modulo(devPtrDescaleV, modulo)); - printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, - is_aligned_modulo(devPtrDescaledO, modulo)); - printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, - is_aligned_modulo(devPtrDescaledO_t, modulo)); - printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); - printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); - printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); - printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, - is_aligned_modulo(devPtrAmaxdQ, modulo)); - printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, - is_aligned_modulo(devPtrAmaxdK, modulo)); - printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, - is_aligned_modulo(devPtrAmaxdV, modulo)); - printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); - printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); - printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, - is_aligned_modulo(devPtrdO_f16, modulo)); - printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); - printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, - is_aligned_modulo(devPtrDescaleQ_t, modulo)); - printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, - is_aligned_modulo(devPtrDescaleK_t, modulo)); - /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index e3aacbf2e6..2aecd032c9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -36,7 +36,7 @@ restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorStorage from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index de5563b7a6..653fd8cfb0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1430,10 +1430,6 @@ def forward( q_fp8, k_fp8, v_fp8 = (None, None, None) # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print( - f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}," - f" is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}" - ) if fp8 and is_input_fp8: QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v @@ -1628,7 +1624,6 @@ def forward( out = None o_format = qkv_format for i in range(cp_size + 1): - print(f">>>>>>>>>>>> {torch.cuda.current_device()}: i: {i}, cp_size: {cp_size}") if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): # wait until KV is received @@ -1917,10 +1912,6 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - print( - f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}," - f" out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}" - ) if enable_mla: out = out.view(o_shape) else: @@ -1969,12 +1960,7 @@ def forward( elif o_format == "sbhd": out = out.view(-1, *out.shape[-3:]) ctx.batch_size = out.shape[1] - print(f"========= {torch.cuda.current_device()}: out.shape: {out.shape} {out.dtype}") out_part = out.to(fwd_nominal_dtype) - print( - f"========= {torch.cuda.current_device()}: out_part.shape:" - f" {out_part.shape} {out_part.dtype}" - ) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) @@ -2077,15 +2063,6 @@ def forward( q_f16 = q_f16.view(q.shape) kv_f16 = kv f16_tensors = (q_f16, kv_f16, out_f16) - if torch.cuda.current_device() == 0: - print( - "fp8_tensors:" - f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" - ) - print( - "f16_tensors:" - f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" - ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -2367,10 +2344,6 @@ def backward(ctx, dout, *_args): # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: - print( - f"========= {torch.cuda.current_device()}: before a2a: out.shape:" - f" {out.shape} {out.dtype} dout.shape: {dout.shape} {dout.dtype}" - ) if not ctx.use_fused_attention: # out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) @@ -2386,10 +2359,6 @@ def backward(ctx, dout, *_args): ctx.cp_stream, True, ) - print( - f"========= {torch.cuda.current_device()}: after a2a: dout.shape:" - f" {dout.shape} {dout.dtype} {q.shape} {ctx.v_shape}" - ) if ctx.enable_mla: out = out.view(*ctx.o_shape) @@ -3272,15 +3241,6 @@ def forward( f16_tensors = (q, k, v, out) else: f16_tensors = (q, k, v, out) - if torch.cuda.current_device() == 0: - print( - "fp8_tensors:" - f" {[x.shape if x is not None else None for x in fp8_tensors]} {[type(x) for x in fp8_tensors]}" - ) - print( - "f16_tensors:" - f" {[x.shape if x is not None else None for x in f16_tensors]} {[type(x) for x in f16_tensors]}" - ) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, @@ -3385,8 +3345,6 @@ def backward(ctx, dout, *_args): if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] - if torch.cuda.current_device() == 0: - print(f"ctx.q_shape: {ctx.q_shape} {ctx.k_shape} {ctx.v_shape}") dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros( (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), @@ -3398,10 +3356,6 @@ def backward(ctx, dout, *_args): dtype=ctx.fwd_nominal_dtype, device=v.device, ) - if torch.cuda.current_device() == 0: - print(f"dq: {dq.shape} {dq.dtype} {dq.device}") - print(f"dk: {dk.shape} {dk.dtype} {dk.device}") - print(f"dv: {dv.shape} {dv.dtype} {dv.device}") dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3512,23 +3466,11 @@ def backward(ctx, dout, *_args): q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer ) - print( - "aux_ctx_tensors:" - f" {len(aux_ctx_tensors)} {[x.shape if x is not None else None for x in aux_ctx_tensors]} {[type(x) for x in aux_ctx_tensors]}" - ) dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( d_out_format, dout_part ) aux_ctx_tensors.append(dout_part) dout_part = ctx.dO_quantizer(dout_part) - print( - "q_part type:" - f" {type(q_part)} {type(k_part)} {type(v_part)} {type(out_part)} {type(dout_part)}" - ) - print( - "q_part shape:" - f" {q_part.shape} {k_part.shape} {v_part.shape} {out_part.shape} {dout_part.shape}" - ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, @@ -3606,11 +3548,6 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if torch.cuda.current_device() == 0: - print( - f"dq.shape: {dq.shape} dq_per_step[i - 1].shape:" - f" {dq_per_step[i - 1].shape}" - ) if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -4431,19 +4368,6 @@ def attn_forward_func_with_cp( in Megatron-LM. """ - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_comm_type=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {qkv_format=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {deterministic=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {use_fused_attention=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_meta=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_group=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_global_ranks=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {cp_stream=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {quantizers=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {pad_between_seqs=}") - print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {fp8_output=}") - # print(f"rank {torch.distributed.get_rank()} attn_forward_func_with_cp: {layer_number=}") if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9a8d38547e..2f9929bffc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2244,10 +2244,6 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): original_shapes = [x.shape for x in [q, k, v]] s_q, d_qk = q.shape[-2:] s_kv, d_v = v.shape[-2:] - print( - f">>>>>>>>>>>> {torch.cuda.current_device()}: {qkv_layout} s_q: {s_q}, d_qk: {d_qk}," - f" s_kv: {s_kv}, d_v: {d_v}, {q.shape}, {k.shape}, {v.shape}" - ) assert s_q % 128 == 0 assert s_kv % 128 == 0 assert d_qk % 32 == 0 From 8500121daf3d3668b59d1eec242460d6d6e0f631 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:13:41 -0800 Subject: [PATCH 49/59] fix combine_and_quantize for f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 2f9929bffc..b572f087b2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2304,7 +2304,7 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout def combine_and_dequantize( From 0c2c4668f394fd3179fece9f4fef47d6efcef13e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:14:32 +0000 Subject: [PATCH 50/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f39ed547cb..1173d91a16 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1806,7 +1806,11 @@ def get_model(dtype, config): # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), "fp8_10": ModelConfig( - 2, 2048, 24, 192, head_dim_v=128, #num_gqa_groups=12, window_size=(512, 512) + 2, + 2048, + 24, + 192, + head_dim_v=128, # num_gqa_groups=12, window_size=(512, 512) ), # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), @@ -1820,7 +1824,7 @@ def get_model(dtype, config): # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.bfloat16] #[torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.bfloat16] # [torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] From 6744aeeb90a85900eebc335ba0fb77370cf9c35c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:01:58 -0800 Subject: [PATCH 51/59] minor tweaks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/backends.py | 17 ++++++++--------- .../dot_product_attention/context_parallel.py | 2 +- .../attention/dot_product_attention/utils.py | 4 ++++ .../pytorch/attention/multi_head_attention.py | 10 ++++++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2aecd032c9..95085a0fca 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1339,14 +1339,15 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ + bwd_requires_o_f16 = is_training and (not is_bwd_fp8 or ( + is_bwd_fp8 and ((fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8()))) + bwd_requires_o_fp8 = is_training and is_bwd_fp8 and ( + fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)) if isinstance(out_, QuantizedTensorStorage): - if not is_output_fp8 or not is_bwd_fp8: + if not is_output_fp8 or bwd_requires_o_f16: out = out_.dequantize().view(out_.shape) else: - if is_output_fp8 or ( - is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): + if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) # print quantizers @@ -1368,12 +1369,10 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or isinstance( - QKV_quantizer, MXFP8Quantizer - ): + if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) - else: + elif fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 653fd8cfb0..cf8986c7f6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4224,7 +4224,7 @@ def backward(ctx, dout, *_args): if ( ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() ) and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index b572f087b2..20ae4d135d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2258,6 +2258,10 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): ) q_fp8, k_fp8, v_fp8 = grouped_tensor.quantized_tensors else: + # grouped_tensor = GroupedTensor.create_and_quantize( + # tensors=[q, k], quantizer=qkv_quantizer + # ) + # q_fp8, k_fp8 = grouped_tensor.quantized_tensors q_fp8 = qkv_quantizer(q) k_fp8 = qkv_quantizer(k) v_fp8 = qkv_quantizer(v) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..801c2f525b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -784,14 +784,16 @@ def forward( fp8_dpa = fp8_recipe.fp8_dpa fp8_mha = fp8_recipe.fp8_mha float8_current_scaling = fp8_recipe.float8_current_scaling() + mxfp8_scaling = fp8_recipe.mxfp8() else: fp8_dpa = _dpa_fp8_recipe_dpa fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling - # DPA: always produce FP8 output when fp8=True to take advantage of the O amax - dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe + qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling and not mxfp8_scaling + # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling From 4cec878b6e6eb40ff9346db658aeca125a6411d8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:02:17 -0800 Subject: [PATCH 52/59] tweak tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 1173d91a16..92b8ade67e 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,13 +1804,14 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), # , attn_mask_type="causal"), + "fp8_9": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), "fp8_10": ModelConfig( 2, - 2048, - 24, + 4096, + 128, 192, - head_dim_v=128, # num_gqa_groups=12, window_size=(512, 512) + head_dim_v=128, + attn_mask_type="causal", ), # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), @@ -1871,7 +1872,7 @@ def test_mha_fp8_vs_f16( ) elif scaling_mode == "mxfp8": fp8_recipe = recipe.MXFP8BlockScaling( - fp8_format=recipe.Format.HYBRID, + fp8_format=recipe.Format.E4M3, fp8_dpa=True, fp8_mha=True, ) @@ -2058,7 +2059,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - print(f"type(out_grad): {type(out_grad)} {out_grad.shape}") with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, @@ -2128,7 +2128,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal ) elif scaling_mode == "mxfp8": fp8_recipe = recipe.MXFP8BlockScaling( - fp8_format=recipe.Format.HYBRID, + fp8_format=recipe.Format.E4M3, fp8_dpa=True, fp8_mha=False, ) @@ -2401,7 +2401,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) From 5c8e939ab2ea123941539573d1b06eee01d8aa94 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 23:03:11 +0000 Subject: [PATCH 53/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 4 ++- .../dot_product_attention/backends.py | 30 +++++++++++++++---- .../dot_product_attention/context_parallel.py | 4 ++- .../pytorch/attention/multi_head_attention.py | 8 ++++- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 92b8ade67e..7ae73a753a 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1804,7 +1804,9 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), + "fp8_9": ModelConfig( + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), "fp8_10": ModelConfig( 2, 4096, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 95085a0fca..906f3ade45 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1339,10 +1339,24 @@ def forward( # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - bwd_requires_o_f16 = is_training and (not is_bwd_fp8 or ( - is_bwd_fp8 and ((fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8()))) - bwd_requires_o_fp8 = is_training and is_bwd_fp8 and ( - fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)) + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) if isinstance(out_, QuantizedTensorStorage): if not is_output_fp8 or bwd_requires_o_f16: out = out_.dequantize().view(out_.shape) @@ -1369,10 +1383,14 @@ def forward( fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) or fp8_recipe.mxfp8(): + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) - elif fp8_recipe.delayed() or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16): + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index cf8986c7f6..56c36aef8a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4224,7 +4224,9 @@ def backward(ctx, dout, *_args): if ( ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() ) and ctx.is_input_fp8: - dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize( + ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 801c2f525b..0a276bdc8a 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -791,7 +791,13 @@ def forward( float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling and not mxfp8_scaling + qkv_fp8_output = ( + fp8 + and fp8_mha + and rotary_pos_emb is None + and not float8_current_scaling + and not mxfp8_scaling + ) # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling # Proj Gemm: match DPA output except for Float8CurrentScaling From 7b6b364499701a5bb4c7e56f0a96aadb9315fa09 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:56:03 -0800 Subject: [PATCH 54/59] fix ds descale_o Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9796e39ddc..f3557eeb68 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2301,6 +2301,43 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); + printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], + q_t_stride[3]); + printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], + k_t_stride[3]); + printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], + dO_t_stride[3]); + printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, + cudnn_frontend::DataType_t::BFLOAT16); + printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, + cudnn_frontend::DataType_t::FP8_E5M2); + printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, + cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); + printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, + NVTE_QKV_Format::NVTE_BHSD); + printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, + NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); + printf("b: %d\n", b); + printf("h: %d\n", h); + printf("hg: %d\n", hg); + printf("s_q: %d\n", s_q); + printf("s_kv: %d\n", s_kv); + printf("d_qk: %d\n", d_qk); + printf("d_v: %d\n", d_v); + printf("is_delayed_scaling: %d\n", is_delayed_scaling); + printf("is_current_scaling: %d\n", is_current_scaling); + printf("is_O_in_F16: %d\n", is_O_in_F16); + printf("is_mxfp8: %d\n", is_mxfp8); + printf("is_causal: %d\n", is_causal); + printf("is_padding: %d\n", is_padding); + printf("is_dropout: %d\n", is_dropout); + printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2323,6 +2360,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + printf("s_q_padded: %d\n", padded.s_q_padded); + printf("s_kv_padded: %d\n", padded.s_kv_padded); + printf("s_q_scale: %d\n", padded.s_q_scale); + printf("s_kv_scale: %d\n", padded.s_kv_scale); + printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); + printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); + printf("d_qk_padded: %d\n", padded.d_qk_padded); + printf("d_v_padded: %d\n", padded.d_v_padded); + printf("d_qk_scale: %d\n", padded.d_qk_scale); + printf("d_v_scale: %d\n", padded.d_v_scale); + printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); + printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2344,6 +2393,20 @@ void fused_attn_fp8_bwd_impl_v1( dO_scale_strides.data(), d_out_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); + printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], + q_scale_strides[2], q_scale_strides[3]); + printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], + q_t_scale_strides[2], q_t_scale_strides[3]); + printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], + k_scale_strides[2], k_scale_strides[3]); + printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], + k_t_scale_strides[2], k_t_scale_strides[3]); + printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], + v_scale_strides[2], v_scale_strides[3]); + printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], + dO_scale_strides[2], dO_scale_strides[3]); + printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], + dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2499,6 +2562,9 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); + printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); + printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_stride) @@ -2646,6 +2712,42 @@ void fused_attn_fp8_bwd_impl_v1( // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } + int64_t modulo = 16; + printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); + printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); + printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); + printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); + printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); + printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); + printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, + is_aligned_modulo(devPtrDescaleQ, modulo)); + printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, + is_aligned_modulo(devPtrDescaleK, modulo)); + printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, + is_aligned_modulo(devPtrDescaleV, modulo)); + printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, + is_aligned_modulo(devPtrDescaledO, modulo)); + printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, + is_aligned_modulo(devPtrDescaledO_t, modulo)); + printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); + printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); + printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); + printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, + is_aligned_modulo(devPtrAmaxdQ, modulo)); + printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, + is_aligned_modulo(devPtrAmaxdK, modulo)); + printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, + is_aligned_modulo(devPtrAmaxdV, modulo)); + printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); + printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); + printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, + is_aligned_modulo(devPtrdO_f16, modulo)); + printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); + printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, + is_aligned_modulo(devPtrDescaleQ_t, modulo)); + printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, + is_aligned_modulo(devPtrDescaleK_t, modulo)); + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { From 462eb4f5ce7da5cae5c8d11d32d334a642242ecc Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:58:32 -0800 Subject: [PATCH 55/59] Revert "fix ds descale_o" This reverts commit cd0bd82e239ff01210338b4e34cb8784109d22ec. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_fp8.cu | 102 ------------------ 1 file changed, 102 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f3557eeb68..9796e39ddc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2301,43 +2301,6 @@ void fused_attn_fp8_bwd_impl_v1( generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_stride.data(), q_format, false); generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_stride.data(), kv_format, false); generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_stride.data(), d_out_format, false); - printf("q_t_stride: %d, %d, %d, %d\n", q_t_stride[0], q_t_stride[1], q_t_stride[2], - q_t_stride[3]); - printf("k_t_stride: %d, %d, %d, %d\n", k_t_stride[0], k_t_stride[1], k_t_stride[2], - k_t_stride[3]); - printf("dO_t_stride: %d, %d, %d, %d\n", dO_t_stride[0], dO_t_stride[1], dO_t_stride[2], - dO_t_stride[3]); - printf("qkv_tensor_type: %d, %d, %d\n", qkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("o_tensor_type: %d, %d, %d\n", o_tensor_type, cudnn_frontend::DataType_t::HALF, - cudnn_frontend::DataType_t::BFLOAT16); - printf("do_tensor_type: %d, %d, %d\n", do_tensor_type, cudnn_frontend::DataType_t::FP8_E4M3, - cudnn_frontend::DataType_t::FP8_E5M2); - printf("dqkv_tensor_type: %d, %d, %d\n", dqkv_tensor_type, - cudnn_frontend::DataType_t::FP8_E4M3, cudnn_frontend::DataType_t::FP8_E5M2); - printf("qkv_layout: %d, %d, %d\n", qkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("o_format: %d, %d, %d\n", o_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("d_out_format: %d, %d, %d\n", d_out_format, NVTE_QKV_Format::NVTE_BSHD, - NVTE_QKV_Format::NVTE_BHSD); - printf("dqkv_layout: %d, %d, %d\n", dqkv_layout, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, - NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); - printf("b: %d\n", b); - printf("h: %d\n", h); - printf("hg: %d\n", hg); - printf("s_q: %d\n", s_q); - printf("s_kv: %d\n", s_kv); - printf("d_qk: %d\n", d_qk); - printf("d_v: %d\n", d_v); - printf("is_delayed_scaling: %d\n", is_delayed_scaling); - printf("is_current_scaling: %d\n", is_current_scaling); - printf("is_O_in_F16: %d\n", is_O_in_F16); - printf("is_mxfp8: %d\n", is_mxfp8); - printf("is_causal: %d\n", is_causal); - printf("is_padding: %d\n", is_padding); - printf("is_dropout: %d\n", is_dropout); - printf("is_bias: %d\n", is_bias); Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q_t") .set_dim({b, h, s_q, d_qk}) @@ -2360,18 +2323,6 @@ void fused_attn_fp8_bwd_impl_v1( .set_data_type(o_tensor_type)); // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); - printf("s_q_padded: %d\n", padded.s_q_padded); - printf("s_kv_padded: %d\n", padded.s_kv_padded); - printf("s_q_scale: %d\n", padded.s_q_scale); - printf("s_kv_scale: %d\n", padded.s_kv_scale); - printf("s_q_scale_padded: %d\n", padded.s_q_scale_padded); - printf("s_kv_scale_padded: %d\n", padded.s_kv_scale_padded); - printf("d_qk_padded: %d\n", padded.d_qk_padded); - printf("d_v_padded: %d\n", padded.d_v_padded); - printf("d_qk_scale: %d\n", padded.d_qk_scale); - printf("d_v_scale: %d\n", padded.d_v_scale); - printf("d_qk_scale_padded: %d\n", padded.d_qk_scale_padded); - printf("d_v_scale_padded: %d\n", padded.d_v_scale_padded); std::vector q_scale_strides(4); std::vector q_t_scale_strides(4); std::vector k_scale_strides(4); @@ -2393,20 +2344,6 @@ void fused_attn_fp8_bwd_impl_v1( dO_scale_strides.data(), d_out_format, false); generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, dO_t_scale_strides.data(), d_out_format, false); - printf("q_scale_strides: %d, %d, %d, %d\n", q_scale_strides[0], q_scale_strides[1], - q_scale_strides[2], q_scale_strides[3]); - printf("q_t_scale_strides: %d, %d, %d, %d\n", q_t_scale_strides[0], q_t_scale_strides[1], - q_t_scale_strides[2], q_t_scale_strides[3]); - printf("k_scale_strides: %d, %d, %d, %d\n", k_scale_strides[0], k_scale_strides[1], - k_scale_strides[2], k_scale_strides[3]); - printf("k_t_scale_strides: %d, %d, %d, %d\n", k_t_scale_strides[0], k_t_scale_strides[1], - k_t_scale_strides[2], k_t_scale_strides[3]); - printf("v_scale_strides: %d, %d, %d, %d\n", v_scale_strides[0], v_scale_strides[1], - v_scale_strides[2], v_scale_strides[3]); - printf("dO_scale_strides: %d, %d, %d, %d\n", dO_scale_strides[0], dO_scale_strides[1], - dO_scale_strides[2], dO_scale_strides[3]); - printf("dO_t_scale_strides: %d, %d, %d, %d\n", dO_t_scale_strides[0], dO_t_scale_strides[1], - dO_t_scale_strides[2], dO_t_scale_strides[3]); descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Descale_q") @@ -2562,9 +2499,6 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides_v1(b, hg, hg, s_q, s_kv, d_qk, d_v, dv_stride.data(), dqkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - printf("dq_stride: %d, %d, %d, %d\n", dq_stride[0], dq_stride[1], dq_stride[2], dq_stride[3]); - printf("dk_stride: %d, %d, %d, %d\n", dk_stride[0], dk_stride[1], dk_stride[2], dk_stride[3]); - printf("dv_stride: %d, %d, %d, %d\n", dv_stride[0], dv_stride[1], dv_stride[2], dv_stride[3]); dQ->set_output(true) .set_dim({b, h, s_q, d_qk}) .set_stride(dq_stride) @@ -2712,42 +2646,6 @@ void fused_attn_fp8_bwd_impl_v1( // variant_pack[descale_dO] = devPtrDescaledO; variant_pack[descale_dO_t] = devPtrDescaledO_t; } - int64_t modulo = 16; - printf("devPtrQ: %p, is_aligned: %d\n", devPtrQ, is_aligned_modulo(devPtrQ, modulo)); - printf("devPtrK: %p, is_aligned: %d\n", devPtrK, is_aligned_modulo(devPtrK, modulo)); - printf("devPtrV: %p, is_aligned: %d\n", devPtrV, is_aligned_modulo(devPtrV, modulo)); - printf("devPtrO: %p, is_aligned: %d\n", devPtrO, is_aligned_modulo(devPtrO, modulo)); - printf("devPtrM: %p, is_aligned: %d\n", devPtrM, is_aligned_modulo(devPtrM, modulo)); - printf("devPtrdO: %p, is_aligned: %d\n", devPtrdO, is_aligned_modulo(devPtrdO, modulo)); - printf("devPtrDescaleQ: %p, is_aligned: %d\n", devPtrDescaleQ, - is_aligned_modulo(devPtrDescaleQ, modulo)); - printf("devPtrDescaleK: %p, is_aligned: %d\n", devPtrDescaleK, - is_aligned_modulo(devPtrDescaleK, modulo)); - printf("devPtrDescaleV: %p, is_aligned: %d\n", devPtrDescaleV, - is_aligned_modulo(devPtrDescaleV, modulo)); - printf("devPtrDescaledO: %p, is_aligned: %d\n", devPtrDescaledO, - is_aligned_modulo(devPtrDescaledO, modulo)); - printf("devPtrDescaledO_t: %p, is_aligned: %d\n", devPtrDescaledO_t, - is_aligned_modulo(devPtrDescaledO_t, modulo)); - printf("devPtrdQ: %p, is_aligned: %d\n", devPtrdQ, is_aligned_modulo(devPtrdQ, modulo)); - printf("devPtrdK: %p, is_aligned: %d\n", devPtrdK, is_aligned_modulo(devPtrdK, modulo)); - printf("devPtrdV: %p, is_aligned: %d\n", devPtrdV, is_aligned_modulo(devPtrdV, modulo)); - printf("devPtrAmaxdQ: %p, is_aligned: %d\n", devPtrAmaxdQ, - is_aligned_modulo(devPtrAmaxdQ, modulo)); - printf("devPtrAmaxdK: %p, is_aligned: %d\n", devPtrAmaxdK, - is_aligned_modulo(devPtrAmaxdK, modulo)); - printf("devPtrAmaxdV: %p, is_aligned: %d\n", devPtrAmaxdV, - is_aligned_modulo(devPtrAmaxdV, modulo)); - printf("devPtrQ_t: %p, is_aligned: %d\n", devPtrQ_t, is_aligned_modulo(devPtrQ_t, modulo)); - printf("devPtrK_t: %p, is_aligned: %d\n", devPtrK_t, is_aligned_modulo(devPtrK_t, modulo)); - printf("devPtrdO_f16: %p, is_aligned: %d\n", devPtrdO_f16, - is_aligned_modulo(devPtrdO_f16, modulo)); - printf("devPtrdO_t: %p, is_aligned: %d\n", devPtrdO_t, is_aligned_modulo(devPtrdO_t, modulo)); - printf("devPtrDescaleQ_t: %p, is_aligned: %d\n", devPtrDescaleQ_t, - is_aligned_modulo(devPtrDescaleQ_t, modulo)); - printf("devPtrDescaleK_t: %p, is_aligned: %d\n", devPtrDescaleK_t, - is_aligned_modulo(devPtrDescaleK_t, modulo)); - /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { From 77995d2d949c667dafe19da1ba9406b2ed7a117e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:43:49 -0800 Subject: [PATCH 56/59] minor fixes for p2p and ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 22 ++++--- .../attention/test_attention_with_cp.py | 64 +++++++++++-------- .../dot_product_attention/context_parallel.py | 44 +++++++++---- 3 files changed, 81 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 5cb43f277a..949dbf3d1c 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -186,13 +186,17 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + deterministic="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # When is_training is False, gradient outputs are None. is_training = is_training == "True" - + if deterministic == "True": + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" @@ -228,7 +232,6 @@ def run_dpa_with_cp( device_count = torch.cuda.device_count() device = rank % device_count torch.cuda.set_device(device) - print(f"rank: {rank}, world_size: {world_size}") logging.info(f"[Rank {rank}] Setup: world_size {world_size}") dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) @@ -330,7 +333,7 @@ def run_dpa_with_cp( dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: + if fp8_mha and scaling_mode != "mxfp8": q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -377,12 +380,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out, max_logit = out if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -438,7 +441,7 @@ def run_dpa_with_cp( qkv_quantizer.amax.fill_(0.0) dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) - if fp8_mha: + if fp8_mha and scaling_mode != "mxfp8": q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] @@ -494,12 +497,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out_, max_logit_ = out_ if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: @@ -528,9 +531,10 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: + print(f"========= {torch.cuda.current_device()}: tensors[{i}].shape: {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}") assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index a5fe8f74f5..dc079a7193 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -28,6 +28,12 @@ pytest_logging_level = logging.getLevelName(logging.root.level) +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) @@ -153,9 +159,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), # GQA "cp_2_1": ModelConfig( - 2, 4096, 16, 192, head_dim_v=128 + 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, @@ -219,23 +225,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: configs = [ - "cp_1_0", - "cp_1_1", - "cp_1_4", - "cp_1_5", + # "cp_1_0", + # "cp_1_1", + # "cp_1_4", + # "cp_1_5", "cp_2_0", "cp_2_1", - "cp_2_2", - "cp_2_3", - "cp_2_4", - "cp_3_1", - "cp_3_2", - "cp_3_4", - "cp_4_2", + # "cp_2_2", + # "cp_2_3", + # "cp_2_4", + # "cp_3_1", + # "cp_3_2", + # "cp_3_4", + # "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} - dtypes = ["bf16", "fp8"] - qkv_formats = ["bshd", "sbhd", "thd"] + dtypes = ["fp8"] #["bf16", "fp8"] + qkv_formats = ["bshd"]#, "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -247,11 +253,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) -@pytest.mark.parametrize("f16_O", [True, False]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) +@pytest.mark.parametrize("f16_O", [True]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): + # # TODO: Remove this once MXFP8 is supported with fp8_bwd=True! + # if scaling_mode == "mxfp8" and fp8_bwd: + # pytest.skip("MXFP8 only works with fp8_bwd=False!") + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -289,8 +299,8 @@ def test_cp_with_fused_attention( pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": pytest.skip("FP8 attention cannot work with bias yet!") - if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("FP8 attention cannot work with sliding window yet!") + # if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + # pytest.skip("FP8 attention cannot work with sliding window yet!") if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": @@ -310,14 +320,14 @@ def test_cp_with_fused_attention( pytest.skip("Only fp8 works with scaling_mode != None!") if dtype == "fp8" and scaling_mode is None: pytest.skip("fp8 only works with scaling_mode != None!") - if ( - dtype == "fp8" - and scaling_mode == "current" - and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] - ): - pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") + # if ( + # dtype == "fp8" + # and scaling_mode == "current" + # and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + # ): + # pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") + pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!") # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: @@ -377,6 +387,7 @@ def test_cp_with_fused_attention( fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -396,6 +407,7 @@ def test_cp_with_fused_attention( scaling_mode=scaling_mode, f16_O=f16_O, is_training=is_training, + deterministic=_deterministic, log_level=pytest_logging_level, ), check=True, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 56c36aef8a..7886c625b4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2020,7 +2020,6 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - q_fp8, kv_fp8 = None, None if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) @@ -3018,6 +3017,7 @@ def forward( fwd_nominal_dtype = q.dtype fp8_meta_kwargs = {} q_fp8, k_fp8, v_fp8 = (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" @@ -3025,9 +3025,12 @@ def forward( if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): + q_f16, k_f16, v_f16 = q, k, v q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, q, k, v, QKV_quantizer ) + else: + q_f16, k_f16, v_f16 = q, k, v if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs["s_quantizer"] = S_quantizer @@ -3149,6 +3152,7 @@ def forward( cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) + if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: @@ -3225,20 +3229,27 @@ def forward( ctx.fp8_recipe = fp8_recipe fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) + ctx.qkv_reshaped = True if ctx.fp8: + q_fp8_save, k_fp8_save, v_fp8_save = None, None, None + if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): + q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) + k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) + v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) if fp8_recipe.delayed(): - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) if fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) f16_tensors = (None, None, None, out) if fp8_recipe.mxfp8(): f16_tensors = (q, k, v, out) elif fp8: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - f16_tensors = (q, k, v, out) + f16_tensors = (q_f16, k_f16, v_f16, out) + ctx.qkv_reshaped = False else: f16_tensors = (q, k, v, out) @@ -3276,13 +3287,17 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - if fp8: + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() ctx.O_quantizer = O_quantizer.copy() ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None - ctx.dQKV_quantizer = dQKV_quantizer.copy() - ctx.dO_quantizer = dO_quantizer.copy() - ctx.dP_quantizer = dP_quantizer.copy() if dP_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer.scale = O_quantizer.scale.clone() @@ -3333,17 +3348,18 @@ def backward(ctx, dout, *_args): dout = dout.view(ctx.out_shape) dout_fp8 = None if ctx.fp8: - if ( - ctx.is_output_fp8 - and not isinstance(dout, QuantizedTensorStorage) - and not ctx.fp8_recipe.mxfp8() - ): + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) if not ctx.fp8_recipe.mxfp8(): dout = dout_fp8._data if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if not ctx.qkv_reshaped: + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) dk = torch.zeros( From 586b698bcb513d536709fd47824df92bfc0e185c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Mar 2026 01:44:38 +0000 Subject: [PATCH 57/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 5 ++++- tests/pytorch/attention/test_attention_with_cp.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 949dbf3d1c..242d6b9e7a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -534,7 +534,10 @@ def run_dpa_with_cp( for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - print(f"========= {torch.cuda.current_device()}: tensors[{i}].shape: {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}") + print( + f"========= {torch.cuda.current_device()}: tensors[{i}].shape:" + f" {tensor.shape} {tensor.dtype} {torch.isnan(tensor).any()} {torch.isinf(tensor).any()}" + ) assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index dc079a7193..10ab2dffe7 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -159,7 +159,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), # GQA + "cp_2_0": ModelConfig( + 2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), # GQA "cp_2_1": ModelConfig( 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA @@ -240,8 +242,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): # "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} - dtypes = ["fp8"] #["bf16", "fp8"] - qkv_formats = ["bshd"]#, "sbhd", "thd"] + dtypes = ["fp8"] # ["bf16", "fp8"] + qkv_formats = ["bshd"] # , "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -327,7 +329,9 @@ def test_cp_with_fused_attention( # ): # pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!") + pytest.skip( + "f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!" + ) # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: # pytest.skip("MLA CP currently only support KV P2P!") # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: From 1e7cd70b4f8a34685cdaa951f8dfe97a7be0ed9b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:42:35 -0800 Subject: [PATCH 58/59] tweak cp test skips Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/test_attention_with_cp.py | 164 ++++++++---------- .../attention/dot_product_attention/utils.py | 59 ++++++- 2 files changed, 121 insertions(+), 102 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 10ab2dffe7..c9cdc6baf8 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -102,25 +102,29 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config.context_parallel = True config.cp_comm_type = cp_comm_type - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + # FlashAttention / CP implementation specific: MLA only with KV P2P if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} @@ -260,99 +264,67 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): - # # TODO: Remove this once MXFP8 is supported with fp8_bwd=True! - # if scaling_mode == "mxfp8" and fp8_bwd: - # pytest.skip("MXFP8 only works with fp8_bwd=False!") + config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") - if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+!") - if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") - if dtype == "fp8" and get_device_compute_capability() < (9, 0): - pytest.skip("FP8 attention is only supported on sm90+!") + if get_device_compute_capability() < (9, 0) and qkv_format == "thd": + pytest.skip("Only sm90+ architectures support THD format!") + if get_device_compute_capability() < (9, 0) and dtype == "fp8": + pytest.skip("Only sm90+ architectures support FP8 attention!") + + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("dtype=fp8 requires fp8_dpa=True or fp8_mha=True!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: - pytest.skip("Only fp8 works with fp8_bwd=True!") - - config = model_configs_fused_attn[model] - config.context_parallel = True - config.cp_comm_type = cp_comm_type + pytest.skip("fp8_bwd=True requires dtype=fp8!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("dtype!=fp8 requires fp8_dpa=False and fp8_mha=False!") - if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - # if dtype == "fp8" and cp_comm_type == "all_gather": - # pytest.skip( - # "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" - # ) if dtype == "fp8" and qkv_format == "thd": - pytest.skip("FP8 attention cannot work with THD format yet!") + pytest.skip("No support for FP8 attention with THD format!") if dtype == "fp8" and config.attn_bias_type != "no_bias": - pytest.skip("FP8 attention cannot work with bias yet!") - # if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - # pytest.skip("FP8 attention cannot work with sliding window yet!") - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): - pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" - ) - if dtype != "fp8" and (fp8_mha or fp8_dpa): - pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") - if dtype == "fp8" and not (fp8_mha or fp8_dpa): - pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") - if dtype != "fp8" and scaling_mode is not None: - pytest.skip("Only fp8 works with scaling_mode != None!") - if dtype == "fp8" and scaling_mode is None: - pytest.skip("fp8 only works with scaling_mode != None!") - # if ( - # dtype == "fp8" - # and scaling_mode == "current" - # and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] - # ): - # pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode not in ["current", "mxfp8"]): + pytest.skip("No support for FP8 attention with bias!") + + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No supprt for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if (config.window_size != (-1, 0) or config.window_size != (-1, -1)) and cp_comm_type in ["p2p", "a2a+p2p"]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( - "f16_O only needs to be tested for dtype = fp8 and scaling_mode in [current, mxfp8]!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) - # if cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] and config.head_dim_qk != config.head_dim_v: - # pytest.skip("MLA CP currently only support KV P2P!") - # if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - # pytest.skip("MLA CP currently does not support FP8 attention!") - if dtype == "fp8" and config.softmax_type != "vanilla": - pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") + + if config.softmax_type != "vanilla" and dtype == "fp8": + pytest.skip("No support for non-vanilla softmax with FP8 attention!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": - pytest.skip( - "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" - ) - if ( - get_cudnn_version() < (9, 18, 0) - and config.softmax_type != "vanilla" - and qkv_format == "thd" - ): - pytest.skip( - "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" - " non-vanilla softmax types!" - ) + pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") + if config.softmax_type != "vanilla" and qkv_format == "thd" and get_cudnn_version() < (9, 18, 0): + pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") + + if dtype == "fp8" and scaling_mode is None: + pytest.skip("dtype=fp8 requires scaling_mode != None!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("dtype!=fp8 requires scaling_mode = None!") + if dtype != "fp8" and not f16_O: + pytest.skip("dtype!=fp8 requires f16_O=True!") + if scaling_mode == "delayed" and f16_O: + pytest.skip("scaling_mode=delayed requires f16_O=False!") if scaling_mode == "mxfp8" and not f16_O: - pytest.skip("MXFP8 only works with f16_O=True!") + pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 20ae4d135d..69a9ee9e03 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -835,12 +835,59 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - # elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: - # logger.debug( - # "Disabling FusedAttention as it does not support context parallelism with FP8" - # " MLA attention" - # ) - # use_fused_attention = False + elif fp8 and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and THD format" + ) + use_fused_attention = False + elif fp8 and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and bias" + ) + use_fused_attention = False + + elif core_attention_bias_type != "no_bias" and cp_comm_type in [ + "all_gather", + "a2a", + "a2a+p2p", + ]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias" + " and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD" + " format and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif ( + window_size is not None + and (window_size != (-1, 0) or window_size != (-1, -1)) + and cp_comm_type in ["p2p", "a2a+p2p"] + ): + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with sliding" + " window attention and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif cp_comm_type in ["a2a", "a2a+p2p"] and ( + num_heads % 2 != 0 or num_gqa_groups % 2 != 0 + ): + logger.debug( + "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" + " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)", + cp_comm_type, + num_heads, + num_gqa_groups, + ) + use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends From 6d7766a4730f9dab47ede2d41249e2fb6c618fff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Mar 2026 02:44:45 +0000 Subject: [PATCH 59/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/test_attention_with_cp.py | 28 ++++++++++++++----- .../attention/dot_product_attention/utils.py | 4 +-- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index c9cdc6baf8..116d4dcc41 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -110,10 +110,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") - if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type in [ - "p2p", - "a2a+p2p", - ]: + if ( + config.window_size != (-1, 0) + and config.window_size != (-1, -1) + and cp_comm_type + in [ + "p2p", + "a2a+p2p", + ] + ): pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") if cp_comm_type in ["a2a", "a2a+p2p"] and ( @@ -299,10 +304,15 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") - if (config.window_size != (-1, 0) or config.window_size != (-1, -1)) and cp_comm_type in ["p2p", "a2a+p2p"]: + if (config.window_size != (-1, 0) or config.window_size != (-1, -1)) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") - if cp_comm_type in ["a2a", "a2a+p2p"] and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" @@ -312,7 +322,11 @@ def test_cp_with_fused_attention( pytest.skip("No support for non-vanilla softmax with FP8 attention!") if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") - if config.softmax_type != "vanilla" and qkv_format == "thd" and get_cudnn_version() < (9, 18, 0): + if ( + config.softmax_type != "vanilla" + and qkv_format == "thd" + and get_cudnn_version() < (9, 18, 0) + ): pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") if dtype == "fp8" and scaling_mode is None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 69a9ee9e03..84f676539b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -877,9 +877,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt cp_comm_type, ) use_fused_attention = False - elif cp_comm_type in ["a2a", "a2a+p2p"] and ( - num_heads % 2 != 0 or num_gqa_groups % 2 != 0 - ): + elif cp_comm_type in ["a2a", "a2a+p2p"] and (num_heads % 2 != 0 or num_gqa_groups % 2 != 0): logger.debug( "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)",