From e841c6b970de4231f568e72a2191de5ddc192fa4 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 13 Mar 2026 13:06:50 -0700 Subject: [PATCH] [ET-VK][matmul] Re-implement fp32/fp16 matmul and linear with tiled compute and blocked weight packing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace all existing matmul/linear operator implementations with new ones built from the ground up using a tiled compute approach. Delete all legacy implementations (MatMulLegacy.cpp, LinearLegacy.cpp, addmm_optimized.glsl, addmm_naive_*.glsl). New matmul (mm/bmm/addmm): - Single matmul.glsl shader handles mm, bmm, and addmm using FPInputTile, FPWeightTile, FPOutTile infrastructure from SDPA - Adaptive tile size selection (TILE_M=4/2/1) based on GPU occupancy - When mat2 is a constant tensor, automatically routes through the linear path for blocked weight packing New linear: - Custom 4OC×4IC blocked weight prepacking via pack_fp_linear_weight.glsl for optimal cache line utilization during tiled matmul - Supports both transposed [N,K] and non-transposed [K,N] weights with batch dimension support - Separate texture2d weight storage with automatic buffer fallback for large dimensions Performance on Adreno 750 (fp16, vs legacy): - Linear [4096,1024]x[256,1024]: 1.33x faster (texture) - Linear [4096,64]x[128,64]: 2.67x faster (texture) - BMM [1,4096,256]x[1,256,1024]: 1.63x faster (texture) Differential Revision: [D96488384](https://our.internmc.facebook.com/intern/diff/D96488384/) [ghstack-poisoned] --- .../graph/ops/glsl/addmm_naive_buffer.glsl | 86 --- .../graph/ops/glsl/addmm_naive_texture3d.glsl | 189 ------- .../graph/ops/glsl/addmm_naive_texture3d.yaml | 24 - .../graph/ops/glsl/addmm_optimized.glsl | 242 -------- .../graph/ops/glsl/addmm_optimized.yaml | 43 -- .../linear_fp_packed_weight_tile_load.glslh | 75 +++ .../runtime/graph/ops/glsl/linear_scalar.glsl | 106 ++++ .../runtime/graph/ops/glsl/linear_scalar.yaml | 39 ++ 
.../runtime/graph/ops/glsl/linear_vec.glsl | 107 ++++ .../runtime/graph/ops/glsl/linear_vec.yaml | 41 ++ .../graph/ops/glsl/matmul_fp_bias_apply.glslh | 98 ++++ .../ops/glsl/matmul_fp_mat1_tile_load.glslh | 99 ++++ .../ops/glsl/matmul_fp_mat2_tile_load.glslh | 99 ++++ .../ops/glsl/matmul_fp_out_tile_store.glslh | 119 ++++ .../runtime/graph/ops/glsl/matmul_scalar.glsl | 101 ++++ .../runtime/graph/ops/glsl/matmul_scalar.yaml | 35 ++ .../runtime/graph/ops/glsl/matmul_vec.glsl | 102 ++++ .../runtime/graph/ops/glsl/matmul_vec.yaml | 36 ++ .../graph/ops/glsl/pack_fp_linear_weight.glsl | 125 +++++ ...buffer.yaml => pack_fp_linear_weight.yaml} | 12 +- .../vulkan/runtime/graph/ops/impl/Linear.cpp | 531 ++++++------------ .../vulkan/runtime/graph/ops/impl/Linear.h | 32 ++ .../vulkan/runtime/graph/ops/impl/MatMul.cpp | 325 ----------- .../vulkan/runtime/graph/ops/impl/MatMul.h | 7 +- .../vulkan/runtime/graph/ops/impl/Matmul.cpp | 265 +++++++++ .../test/custom_ops/impl/TestMatmulLinear.cpp | 74 +++ backends/vulkan/test/custom_ops/targets.bzl | 1 + backends/vulkan/test/custom_ops/test_mm.cpp | 465 +++++++++++++++ backends/vulkan/test/op_tests/cases.py | 126 +++-- 29 files changed, 2290 insertions(+), 1314 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.yaml delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_scalar.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_scalar.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_vec.glsl create mode 100644 
backends/vulkan/runtime/graph/ops/glsl/linear_vec.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_fp_bias_apply.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat1_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat2_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_fp_out_tile_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_vec.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_vec.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl rename backends/vulkan/runtime/graph/ops/glsl/{addmm_naive_buffer.yaml => pack_fp_linear_weight.yaml} (70%) create mode 100644 backends/vulkan/runtime/graph/ops/impl/Linear.h delete mode 100644 backends/vulkan/runtime/graph/ops/impl/MatMul.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Matmul.cpp create mode 100644 backends/vulkan/test/custom_ops/impl/TestMatmulLinear.cpp create mode 100644 backends/vulkan/test/custom_ops/test_mm.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl deleted file mode 100644 index d845970f692..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -${define_required_extensions("buffer", DTYPE)} - -#define PRECISION ${PRECISION} - -$if HAS_BIAS: - #define HAS_BIAS - -#define T ${buffer_scalar_type(DTYPE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")} -$if HAS_BIAS: - ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} -${layout_declare_ubo(B, "ivec4", "out_sizes")} -${layout_declare_ubo(B, "ivec4", "out_strides")} -${layout_declare_ubo(B, "ivec4", "mat1_sizes")} -${layout_declare_ubo(B, "ivec4", "mat1_strides")} -${layout_declare_ubo(B, "ivec4", "mat2_sizes")} -${layout_declare_ubo(B, "ivec4", "mat2_strides")} -${layout_declare_ubo(B, "int", "out_numel")} -$if HAS_BIAS: - ${layout_declare_ubo(B, "float", "alpha", "float", "beta")} - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")} - -void main() { - const ivec4 out_tidx = ivec4( - gl_GlobalInvocationID.x, - gl_GlobalInvocationID.y, - gl_GlobalInvocationID.z % out_sizes.z, - gl_GlobalInvocationID.z / out_sizes.z); - - if (any(greaterThanEqual(out_tidx, out_sizes))) { - return; - } - - int mat1_bufi = tidx_to_bufi( - ivec4(0, out_tidx.y, out_tidx.z, out_tidx.w), mat1_strides); - int mat2_bufi; - if (mat2_is_transposed > 0) { - mat2_bufi = tidx_to_bufi( - ivec4(0, out_tidx.x, 0, 0), mat2_strides); - } else { - mat2_bufi = tidx_to_bufi( - ivec4(out_tidx.x, 0, out_tidx.z, out_tidx.w), mat2_strides); - } - - int mat2_stride; - if (mat2_is_transposed > 0) { - mat2_stride = mat2_strides.x; - } else { - mat2_stride = mat2_strides.y; - } - - T sum = T(0.0); - for (int i = 0; i < mat1_sizes.x; ++i) { - sum += t_mat1[mat1_bufi] * t_mat2[mat2_bufi]; - - mat1_bufi += mat1_strides.x; - mat2_bufi += mat2_stride; - } - - const int out_bufi = 
tidx_to_bufi(out_tidx, out_strides); -#ifdef HAS_BIAS - t_out[out_bufi] = T(alpha) * T(sum) + T(beta) * t_bias[out_tidx.x]; -#else - t_out[out_bufi] = T(sum); -#endif // HAS_BIAS -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl deleted file mode 100644 index 3d5814eb6d0..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -$if MAT2_IS_TRANSPOSED: - #define MAT2_IS_TRANSPOSED - -$if HAS_BIAS: - #define HAS_BIAS - -${layout_declare_tensor(B, "w", "out_tensor", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "mat1_tensor", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")} -$if HAS_BIAS: - ${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")} - -layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 mat2_sizes; - ivec3 out_limits; - $if HAS_BIAS: - ivec4 bias_sizes; - float alpha; - float beta; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int out_packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "mat1_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 mat1_axis_map = unhash_axis_map(mat1_layout); -const lowp int mat1_packed_dim = unhash_packed_dim(mat1_layout); - -${layout_declare_spec_const(C, "int", "mat2_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 mat2_axis_map = 
unhash_axis_map(mat2_layout); -const lowp int mat2_packed_dim = unhash_packed_dim(mat2_layout); - -$if HAS_BIAS: - ${layout_declare_spec_const(C, "int", "bias_layout", "DEFAULT_LAYOUT")} - const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout); - const lowp int bias_packed_dim = unhash_packed_dim(bias_layout); - -#ifdef HAS_BIAS -vec4 get_bias_texel_W_packed(ivec3 logical_pos) { - ivec3 bias_pos = ivec3(0); - if (bias_sizes.y == 1) { - bias_pos[bias_axis_map.y] = 0; - } else { - bias_pos[bias_axis_map.y] = logical_pos.y; - } - if (bias_sizes.x == 1) { - bias_pos[bias_axis_map.x] = 0; - vec4 bias_texel = texelFetch(bias_tensor, bias_pos, 0); - // Only the first value is valid, the rest is 0 padding - return vec4(bias_texel.x); - } else { - bias_pos[bias_axis_map.x] = logical_pos.x; - } - - return texelFetch(bias_tensor, bias_pos, 0); -} -#endif // HAS_BIAS - -vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) { - ivec3 mat1_pos; - mat1_pos[mat1_axis_map.x] = 0; - mat1_pos[mat1_axis_map.y] = out_lpos.y; - mat1_pos[mat1_axis_map.z] = out_lpos.z; -#ifdef MAT2_IS_TRANSPOSED - const int mat2_k_axis = mat2_axis_map.x; - const int mat2_row_axis = mat2_axis_map.y; -#else - const int mat2_k_axis = mat2_axis_map.y; - const int mat2_row_axis = mat2_axis_map.x; -#endif // MAT2_IS_TRANSPOSED - - vec4 texel = vec4(0); - const int K = divup4(mat1_sizes.x); - - for (int i = 0; i < K; ++i) { - const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0); - - vec4 sums; - for (int r = 0; r < 4; ++r) { - // On-demand construction of mat2_pos appears to provide the lowest - // latency. Surprisingly, this doesn't translate to mat1_pos. 
- ivec3 mat2_pos = ivec3(0); - mat2_pos[mat2_k_axis] = i; - mat2_pos[mat2_row_axis] = out_lpos.x * 4 + r; -#ifndef MAT2_IS_TRANSPOSED - mat2_pos[mat2_axis_map.z] = out_lpos.z; -#endif // MAT2_IS_TRANSPOSED - sums[r] = dot(mat1_tex, texelFetch(mat2_tensor, mat2_pos, 0)); - } - - texel += sums; - - mat1_pos[mat1_axis_map.x]++; - } - - return texel; -} - -vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) { - ivec3 mat1_pos; - mat1_pos[mat1_axis_map.x] = 0; - mat1_pos[mat1_axis_map.y] = out_lpos.y; - mat1_pos[mat1_axis_map.z] = out_lpos.z; - - ivec3 mat2_pos; - mat2_pos[mat2_axis_map.x] = out_lpos.x; - mat2_pos[mat2_axis_map.y] = 0; - mat2_pos[mat2_axis_map.z] = out_lpos.z; - - ivec3 mat2_pos_offset = ivec3(0); - mat2_pos_offset[mat2_axis_map.y] = 1; - - const int mat2_y_axis = mat2_axis_map.y; - - vec4 texel = vec4(0); - const int K = divup4(mat1_sizes.x); - - for (int i = 0; - i < K; - ++i, mat1_pos[mat1_axis_map.x]++, mat2_pos[mat2_axis_map.y]+=4) { - const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0); - - for (int r = 0; r < 4; ++r) { - if (4 * i + r >= mat2_sizes.y) { - continue; - } - // On-demand construction of mat2_pos appears to provide the lowest - // latency. Surprisingly, this doesn't translate to mat1_pos. 
- ivec3 mat2_pos = ivec3(0); - mat2_pos[mat2_axis_map.x] = out_lpos.x; - mat2_pos[mat2_axis_map.y] = 4 * i + r; - mat2_pos[mat2_axis_map.z] = out_lpos.z; - - vec4 mat1_comp_vec = vec4(mat1_tex[r]); - texel = fma(mat1_comp_vec, texelFetch(mat2_tensor, mat2_pos, 0), texel); - } - } - - return texel; -} - -void main() { - const ivec3 out_lpos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(out_lpos, out_limits))) { - return; - } - - vec4 texel = vec4(0); - -#ifdef MAT2_IS_TRANSPOSED - if (mat2_packed_dim == W_DIM) { - texel = matmul_naive_k_dim_packed(out_lpos); - } else { - texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos); - } -#else - if (mat2_packed_dim == W_DIM) { - texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos); - } else { - texel = matmul_naive_k_dim_packed(out_lpos); - } -#endif // MAT2_IS_TRANSPOSED - -#ifdef HAS_BIAS - vec4 bias_texel = get_bias_texel_W_packed(out_lpos); - texel = beta * bias_texel + alpha * texel; -#endif // HAS_BIAS - - write_texel_lpos(out_tensor, out_lpos, texel, out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.yaml deleted file mode 100644 index 33b617eed13..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -addmm_naive_texture3d: - parameter_names_with_default_values: - DTYPE: float - MAT2_IS_TRANSPOSED: false - HAS_BIAS: true - generate_variant_forall: - DTYPE: - - VALUE: float - - VALUE: half - shader_variants: - - NAME: addmm_naive_texture3d - - NAME: matmul_naive_texture3d - HAS_BIAS: false - - NAME: linear_naive_texture3d - MAT2_IS_TRANSPOSED: true - - NAME: matmul_transposed_naive_texture3d - MAT2_IS_TRANSPOSED: true - HAS_BIAS: false diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl deleted file mode 100644 index 05c227f302c..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl +++ /dev/null @@ -1,242 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -$if MAT2_IS_TRANSPOSED: - #define MAT2_IS_TRANSPOSED - -$if BATCH_MODE: - #define BATCH_MODE - -$if HAS_BIAS: - #define HAS_BIAS - -${layout_declare_tensor(B, "w", "out_tensor", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "mat1_tensor", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")} -$if HAS_BIAS: - ${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")} -${layout_declare_ubo(B, "ivec4", "out_sizes")} -${layout_declare_ubo(B, "ivec4", "mat1_sizes")} -${layout_declare_ubo(B, "ivec4", "mat2_sizes")} -$if HAS_BIAS: - ${layout_declare_ubo(B, "ivec4", "bias_sizes")} - ${layout_declare_ubo(B, "float", "alpha", "float", "beta")} - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int out_packed_dim = 
unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "mat1_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 mat1_axis_map = unhash_axis_map(mat1_layout); - -${layout_declare_spec_const(C, "int", "mat2_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 mat2_axis_map = unhash_axis_map(mat2_layout); - -$if HAS_BIAS: - ${layout_declare_spec_const(C, "int", "bias_layout", "DEFAULT_LAYOUT")} - const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout); - -// To convince the SPIR-V compiler to unroll the loops optimally, need this -// macro -#define FOUR 4 - -#define TILE_ROWS ${TILE_ROWS} - -// we avoid mat4 and vec4 usage here as they compile to much less efficient -// SPIR-V -struct FloatMatrix_2d { - float data[TILE_ROWS][FOUR]; -}; - -struct FloatMatrix_3d { - float data[TILE_ROWS][FOUR][FOUR]; -}; - -#ifdef BATCH_MODE - #define FloatMatrix FloatMatrix_3d -#else - #define FloatMatrix FloatMatrix_2d -#endif // BATCH_MODE - -#ifdef HAS_BIAS -// get texel from self tensor (channel_packed) in addmm -vec4 get_texel_C_packed(const ivec2 idx) { - ivec3 bias_pos = ivec3(0); - if (bias_sizes.x > 1) { - bias_pos[bias_axis_map.x] = idx.x; - } - if (bias_sizes.y > 1) { - bias_pos[bias_axis_map.y] = idx.y; - } - - return texelFetch(bias_tensor, bias_pos, 0); -} -#endif // HAS_BIAS - -FloatMatrix matmul_partial(const ivec4 out_idx_tl) { - FloatMatrix results; - for (int i = 0; i < TILE_ROWS; i++) { - for (int j = 0; j < FOUR; j++) { -#ifdef BATCH_MODE - for (int k = 0; k < FOUR; k++) { - results.data[i][j][k] = 0.0f; - } -#else - results.data[i][j] = 0.0f; -#endif // BATCH_MODE - } - } - vec4 mat1_tensor_partial_load[TILE_ROWS]; - vec4 mat2_tensor_partial_load[FOUR]; - -#ifdef MAT2_IS_TRANSPOSED - const int mat2_k_axis = mat2_axis_map.x; - const int mat2_row_axis = mat2_axis_map.y; -#else - const int mat2_k_axis = mat2_axis_map.y; - const int mat2_row_axis = mat2_axis_map.x; -#endif // MAT2_IS_TRANSPOSED - -#ifdef BATCH_MODE - for (int batch_idx = 0; batch_idx < 
FOUR; batch_idx++) { - if (out_idx_tl.z + batch_idx >= out_sizes.z) { - break; - } -#endif // BATCH_MODE - for (int k = 0; k < mat1_sizes.x; k+=4) { - const int k_div4 = k >> 2; - // read and cache (4 x TILE_ROWS) tile of mat1 - for (int r = 0; r < TILE_ROWS; r++) { - ivec3 mat1_pos = ivec3(0); - mat1_pos[mat1_axis_map.x] = k_div4; - mat1_pos[mat1_axis_map.y] = out_idx_tl.y + r; -#ifdef BATCH_MODE - mat1_pos[mat1_axis_map.z] = out_idx_tl.z + batch_idx; -#endif // BATCH_MODE - - mat1_tensor_partial_load[r] = texelFetch(mat1_tensor, mat1_pos, 0); - } - - // read and cache (4 x 4) tile of mat2 - for (int r = 0; r < FOUR; ++r) { - ivec3 mat2_pos = ivec3(0); - mat2_pos[mat2_k_axis] = k_div4; - mat2_pos[mat2_row_axis] = out_idx_tl.x + r; -#if defined(BATCH_MODE) && !defined(MAT2_IS_TRANSPOSED) - mat2_pos[mat2_axis_map.z] = out_idx_tl.z + batch_idx; -#endif // BATCH_MODE - - mat2_tensor_partial_load[r] = texelFetch(mat2_tensor, mat2_pos, 0); - } - - // perform partial dot products and add partial result to results - for (int out_row = 0; out_row < TILE_ROWS; out_row++) { - for (int out_col = 0; out_col < FOUR; out_col++) { -#ifdef BATCH_MODE - results.data[out_row][out_col][batch_idx] += -#else - results.data[out_row][out_col] += -#endif // BATCH_MODE - dot(mat1_tensor_partial_load[out_row], mat2_tensor_partial_load[out_col]); - } - } - } -#ifdef BATCH_MODE - } -#endif // BATCH_MODE - - return results; -} - -// -// Write result matrix to output (3D matmul) -// - -void write_results_C_packed(const ivec4 out_idx_tl, FloatMatrix results) { - ivec3 out_pos = tidx_to_pos( - out_idx_tl, out_sizes, out_axis_map, out_packed_dim); - - for (int tile_c = 0; - tile_c < TILE_ROWS; - tile_c++, out_pos[out_axis_map.y]++) { - out_pos[out_axis_map.x] = out_idx_tl.x; - - for (int tile_r = 0; - tile_r < FOUR; - tile_r++, out_pos[out_axis_map.x]++) { - -#ifdef HAS_BIAS - ivec2 bias_idx; - bias_idx[bias_axis_map.x] = out_pos[out_axis_map.x]; - bias_idx[bias_axis_map.y] = 
out_pos[out_axis_map.y]; - float bias_val = get_texel_C_packed(bias_idx).x; -#ifdef BATCH_MODE - vec4 bias_texel = vec4(bias_val); -#else - vec4 bias_texel = vec4(bias_val, 0, 0, 0); -#endif // BATCH_MODE -#endif // HAS_BIAS - -#ifdef BATCH_MODE - vec4 out_texel = vec4( - results.data[tile_c][tile_r][0], - results.data[tile_c][tile_r][1], - results.data[tile_c][tile_r][2], - results.data[tile_c][tile_r][3]); -#else - vec4 out_texel = vec4( - results.data[tile_c][tile_r], - 0.0, - 0.0, - 0.0); -#endif // BATCH_MODE - -#ifdef HAS_BIAS - imageStore(out_tensor, out_pos, beta * bias_texel + alpha * out_texel); -#else - imageStore(out_tensor, out_pos, out_texel); -#endif // HAS_BIAS - } - } -} - -void main() { - // Each thread is responsible for calculating a (4 x TILE_ROWS x 1) tile of - // output elements. If the input matrices are 3D, then a (4 x TILE_ROWS x 4) - // tile of output elements will be computed. Note the sizes are written in - // (W x H x C) format. - const ivec3 tile_idx = ivec3(gl_GlobalInvocationID); - - // Calculate the tensor index of the top left element in the output tile - const ivec4 out_idx_topleft = ivec4( - tile_idx.x * 4, - tile_idx.y * TILE_ROWS, -#ifdef BATCH_MODE - tile_idx.z * 4, -#else - tile_idx.z, -#endif // BATCH_MODE - 0); - - // If the top left element is already out of range, then skip - if (any(greaterThanEqual(out_idx_topleft, out_sizes))) { - return; - } - - FloatMatrix results = matmul_partial(out_idx_topleft); - - write_results_C_packed(out_idx_topleft, results); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml deleted file mode 100644 index c82c2003d20..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -addmm_optimized: - parameter_names_with_default_values: - DTYPE: float - MAT2_IS_TRANSPOSED: false - BATCH_MODE: false - TILE_ROWS: 4 - HAS_BIAS: true - generate_variant_forall: - TILE_ROWS: - - VALUE: 4 - SUFFIX: tile_row_4 - - VALUE: 2 - SUFFIX: tile_row_2 - DTYPE: - - VALUE: float - - VALUE: half - shader_variants: - - NAME: addmm_optimized - - NAME: matmul_optimized - HAS_BIAS: false - - NAME: linear_optimized - MAT2_IS_TRANSPOSED: true - - NAME: matmul_transposed_optimized - MAT2_IS_TRANSPOSED: true - HAS_BIAS: false - - NAME: batch_addmm_optimized - BATCH_MODE: true - - NAME: batch_matmul_optimized - BATCH_MODE: true - HAS_BIAS: false - - NAME: batch_linear_optimized - MAT2_IS_TRANSPOSED: true - BATCH_MODE: true - - NAME: batch_matmul_transposed_optimized - MAT2_IS_TRANSPOSED: true - BATCH_MODE: true - HAS_BIAS: false diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh new file mode 100644 index 00000000000..36b2a7296ef --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Tile loader for prepacked fp linear weights in 4OC x 4IC blocked layout. 
+ * + * Assume the following variables are defined in the shader layout: + * - t_weight_packed + * + * Macro Settings: + * - WEIGHT_BUFFER + */ + +#ifndef LINEAR_FP_PACKED_WEIGHT_TILE_LOAD_GLSLH +#define LINEAR_FP_PACKED_WEIGHT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_weight_tile.glslh" + +VEC4_T load_packed_weight_x4( + const int n4, const int dk, const int k4, const int b, const int K4, const int N4) { +#ifdef WEIGHT_BUFFER + return t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]; +#else + return VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0)); +#endif +} + +void load_packed_weight_tile_no_checks( + out FPWeightTile tile, + const int n4_start, + const int k4_start, + const int b, + const int N4, + const int K4) { + [[unroll]] for (int dk4 = 0; dk4 < TILE_K4; dk4++) { + const int k4 = k4_start + dk4; + [[unroll]] for (int dk = 0; dk < 4; dk++) { + const int k = dk4 * 4 + dk; + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4); + } + } + } +} + +void load_packed_weight_tile_with_checks( + out FPWeightTile tile, + const int n4_start, + const int k4_start, + const int b, + const int N4, + const int K4) { + [[unroll]] for (int dk4 = 0; dk4 < TILE_K4; dk4++) { + const int k4 = k4_start + dk4; + [[unroll]] for (int dk = 0; dk < 4; dk++) { + const int k = dk4 * 4 + dk; + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (k4 < K4 && n4_start + n4 < N4) { + tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4); + } else { + tile.data[k][n4] = VEC4_T(0); + } + } + } + } +} + +#endif // LINEAR_FP_PACKED_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_scalar.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_scalar.glsl new file mode 100644 index 00000000000..9f5c2ee6a50 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_scalar.glsl @@ -0,0 
+1,106 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +#define OUTPUT_BUFFER +#define INPUT_BUFFER +#define SCALAR_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER +$if HAS_BIAS: + #define HAS_BIAS + #define BIAS_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(STORAGE, DTYPE)} +$if WEIGHT_STORAGE != STORAGE: + ${define_required_extensions(WEIGHT_STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, WEIGHT_STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=True)} + +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + +$if HAS_BIAS: + layout(push_constant) uniform restrict Block { + int weight_B; + float alpha; + float beta; + }; +$else: + layout(push_constant) uniform restrict Block { + int weight_B; + }; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "matmul_fp_mat1_tile_load.glslh" +#include "linear_fp_packed_weight_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "matmul_fp_out_tile_store.glslh" +#include "matmul_fp_bias_apply.glslh" + +void main() { + 
const int tile_idx_n = int(gl_GlobalInvocationID.x); + const int tile_idx_m = int(gl_GlobalInvocationID.y); + + const int n4_start = tile_idx_n * TILE_N4; + const int m_start = tile_idx_m * TILE_M; + + const int K = mat1_sizes.x; + const int M = mat1_sizes.y; // mat1 [M, K] in WHCN = {K, M, 1, 1} + const int K4 = div_up_4(K); + const int N = out_sizes.x; + const int N4 = div_up_4(N); + + if (n4_start >= N4 || m_start >= M) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + FPWeightTile w_tile; + + const int b = int(gl_GlobalInvocationID.z); + + for (int k4 = 0; k4 < K4; k4++) { + load_mat1_tile_scalar(in_tile, k4, m_start, b, K4, K, M); + load_packed_weight_tile_with_checks(w_tile, n4_start, k4, b % weight_B, N4, K4); + fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile); + } + +#ifdef HAS_BIAS + apply_bias(out_tile, n4_start, m_start, N4, N); +#endif + + store_matmul_out_tile_scalar(out_tile, n4_start, m_start, b, N4, N, M); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_scalar.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_scalar.yaml new file mode 100644 index 00000000000..460c36dc967 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_scalar.yaml @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +linear_scalar: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + WEIGHT_STORAGE: texture2d + HAS_BIAS: false + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + TILE_M: 4 + generate_variant_forall: + combination: + parameter_names: [STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: linear + - NAME: linear_tile_row_2 + TILE_M: 2 + - NAME: linear_tile_row_1 + TILE_M: 1 + - NAME: linear_bias + HAS_BIAS: true + - NAME: linear_bias_tile_row_2 + HAS_BIAS: true + TILE_M: 2 + - NAME: linear_bias_tile_row_1 + HAS_BIAS: true + TILE_M: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_vec.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_vec.glsl new file mode 100644 index 00000000000..adf5272673e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_vec.glsl @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +$if STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER +$if HAS_BIAS: + #define HAS_BIAS +$if STORAGE == "buffer" and HAS_BIAS: + #define BIAS_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(STORAGE, DTYPE)} +$if WEIGHT_STORAGE != STORAGE: + ${define_required_extensions(WEIGHT_STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, WEIGHT_STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + +$if HAS_BIAS: + layout(push_constant) uniform restrict Block { + int weight_B; + float alpha; + float beta; + }; +$else: + layout(push_constant) uniform restrict Block { + int weight_B; + }; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "matmul_fp_mat1_tile_load.glslh" +#include "linear_fp_packed_weight_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "matmul_fp_out_tile_store.glslh" +#include "matmul_fp_bias_apply.glslh" + +void main() { + const int tile_idx_n = int(gl_GlobalInvocationID.x); + const int tile_idx_m = int(gl_GlobalInvocationID.y); + + const int n4_start = tile_idx_n * TILE_N4; + const int m_start = tile_idx_m 
* TILE_M; + + const int K = mat1_sizes.x; + const int M = mat1_sizes.y; // mat1 [M, K] in WHCN = {K, M, 1, 1} + const int K4 = div_up_4(K); + const int N = out_sizes.x; + const int N4 = div_up_4(N); + + if (n4_start >= N4 || m_start >= M) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + FPWeightTile w_tile; + + const int b = int(gl_GlobalInvocationID.z); + + for (int k4 = 0; k4 < K4; k4++) { + load_mat1_tile_with_checks(in_tile, k4, m_start, b, K4, M); + load_packed_weight_tile_with_checks(w_tile, n4_start, k4, b % weight_B, N4, K4); + fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile); + } + +#ifdef HAS_BIAS + apply_bias_vec(out_tile, n4_start, m_start, N4, N); +#endif + + store_matmul_out_tile_with_checks(out_tile, n4_start, m_start, b, N4, M); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_vec.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_vec.yaml new file mode 100644 index 00000000000..ec85c1ab2e4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_vec.yaml @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +linear_vec: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + WEIGHT_STORAGE: texture2d + HAS_BIAS: false + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + TILE_M: 4 + generate_variant_forall: + combination: + parameter_names: [STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [texture3d, texture2d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, texture2d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: linear_vec + - NAME: linear_vec_tile_row_2 + TILE_M: 2 + - NAME: linear_vec_tile_row_1 + TILE_M: 1 + - NAME: linear_vec_bias + HAS_BIAS: true + - NAME: linear_vec_bias_tile_row_2 + HAS_BIAS: true + TILE_M: 2 + - NAME: linear_vec_bias_tile_row_1 + HAS_BIAS: true + TILE_M: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_bias_apply.glslh b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_bias_apply.glslh new file mode 100644 index 00000000000..e8a993fdaaa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_bias_apply.glslh @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_bias + * - bias_sizes + * - alpha, beta (push constants) + * + * Macro Settings: + * - HAS_BIAS + * - SCALAR_BUFFER + * - BIAS_BUFFER + */ + +#ifndef MATMUL_FP_BIAS_APPLY_GLSLH +#define MATMUL_FP_BIAS_APPLY_GLSLH + +#ifdef HAS_BIAS + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +#ifndef SCALAR_BUFFER +void apply_bias_vec( + inout FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4, + const int N) { + const int b_N = bias_sizes.x; + const int b_M = bias_sizes.y; + const int b_N4 = div_up_4(b_N); + + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int bias_row = min(m_start + m, b_M - 1); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + VEC4_T bias_val = VEC4_T(0.0); + if (n4_start + n4 < N4) { + const int bias_n4 = min(n4_start + n4, b_N4 - 1); +#ifdef BIAS_BUFFER + bias_val = t_bias[bias_row * b_N4 + bias_n4]; +#else + bias_val = texelFetch(t_bias, ivec3(bias_n4, bias_row, 0), 0); +#endif + if (b_N == 1) { + bias_val = VEC4_T(bias_val.x); + } + } + out_tile.data[m][n4] = + VEC4_T(alpha) * out_tile.data[m][n4] + VEC4_T(beta) * bias_val; + } + } +} +#endif // !SCALAR_BUFFER + +#ifdef SCALAR_BUFFER +void apply_bias( + inout FPOutTile out_tile, + const int n4_start, + const int m_start, + const int N4, + const int N) { + const int b_N = bias_sizes.x; + const int b_M = bias_sizes.y; + + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int bias_row = min(m_start + m, b_M - 1); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + VEC4_T bias_val = VEC4_T(0.0); + if (n4_start + n4 < N4) { + const int col = mul_4(n4_start + n4); + const int bc0 = min(col, b_N - 1); + const int bc1 = min(col + 1, b_N - 1); + const int bc2 = min(col + 2, b_N - 1); + const int bc3 = min(col + 3, b_N - 1); + T b0 = t_bias[bias_row * b_N + bc0]; + T b1 = (col + 1 < N) ? 
t_bias[bias_row * b_N + bc1] : T(0); + T b2 = (col + 2 < N) ? t_bias[bias_row * b_N + bc2] : T(0); + T b3 = (col + 3 < N) ? t_bias[bias_row * b_N + bc3] : T(0); + bias_val = VEC4_T(b0, b1, b2, b3); + } + out_tile.data[m][n4] = + VEC4_T(alpha) * out_tile.data[m][n4] + VEC4_T(beta) * bias_val; + } + } +} +#endif // SCALAR_BUFFER + +#endif // HAS_BIAS + +#endif // MATMUL_FP_BIAS_APPLY_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat1_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat1_tile_load.glslh new file mode 100644 index 00000000000..a6ce9c0af76 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat1_tile_load.glslh @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_mat1 + * + * Macro Settings: + * - INPUT_BUFFER + */ + +#ifndef MATMUL_FP_MAT1_TILE_LOAD_GLSLH +#define MATMUL_FP_MAT1_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +#ifndef SCALAR_BUFFER +VEC4_T load_mat1_x4( + const int k4, + const int m, + const int b, + const int K4, + const int M) { +#ifdef INPUT_BUFFER + return t_mat1[(b * M + m) * K4 + k4]; +#else + return texelFetch(t_mat1, ivec3(k4, m, b), 0); +#endif +} + +void load_mat1_tile_no_checks( + out FPInputTile tile, + const int k4_start, + const int m_start, + const int b, + const int K4, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + tile.data[m][k4] = + load_mat1_x4(k4_start + k4, m_start + m, b, K4, M); + } + } +} + +void load_mat1_tile_with_checks( + out FPInputTile tile, + const int k4_start, + const int m_start, + const int b, + const int K4, + const int M) { + [[unroll]] 
for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (k4_start + k4 < K4 && m_start + m < M) { + tile.data[m][k4] = + load_mat1_x4(k4_start + k4, m_start + m, b, K4, M); + } else { + tile.data[m][k4] = VEC4_T(0.0); + } + } + } +} +#endif // !SCALAR_BUFFER + +#ifdef SCALAR_BUFFER +void load_mat1_tile_scalar( + out FPInputTile tile, + const int k4_start, + const int m_start, + const int b, + const int K4, + const int K, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (k4_start + k4 < K4 && m_start + m < M) { + const int base = (b * M + m_start + m) * K + mul_4(k4_start + k4); + T s0 = t_mat1[base]; + T s1 = (mul_4(k4_start + k4) + 1 < K) ? t_mat1[base + 1] : T(0); + T s2 = (mul_4(k4_start + k4) + 2 < K) ? t_mat1[base + 2] : T(0); + T s3 = (mul_4(k4_start + k4) + 3 < K) ? t_mat1[base + 3] : T(0); + tile.data[m][k4] = VEC4_T(s0, s1, s2, s3); + } else { + tile.data[m][k4] = VEC4_T(0.0); + } + } + } +} +#endif // SCALAR_BUFFER + +#endif // MATMUL_FP_MAT1_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat2_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat2_tile_load.glslh new file mode 100644 index 00000000000..8da837e3097 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_mat2_tile_load.glslh @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_mat2 + * + * Macro Settings: + * - MAT2_BUFFER + */ + +#ifndef MATMUL_FP_MAT2_TILE_LOAD_GLSLH +#define MATMUL_FP_MAT2_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_weight_tile.glslh" + +#ifndef SCALAR_BUFFER +VEC4_T load_mat2_x4( + const int n4, + const int k, + const int b, + const int N4, + const int K) { +#ifdef MAT2_BUFFER + return t_mat2[(b * K + k) * N4 + n4]; +#else + return texelFetch(t_mat2, ivec3(n4, k, b), 0); +#endif +} + +void load_mat2_tile_no_checks( + out FPWeightTile tile, + const int n4_start, + const int k_start, + const int b, + const int N4, + const int K) { + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[k][n4] = + load_mat2_x4(n4_start + n4, k_start + k, b, N4, K); + } + } +} + +void load_mat2_tile_with_checks( + out FPWeightTile tile, + const int n4_start, + const int k_start, + const int b, + const int N4, + const int K) { + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (k_start + k < K && n4_start + n4 < N4) { + tile.data[k][n4] = + load_mat2_x4(n4_start + n4, k_start + k, b, N4, K); + } else { + tile.data[k][n4] = VEC4_T(0.0); + } + } + } +} +#endif // !SCALAR_BUFFER + +#ifdef SCALAR_BUFFER +void load_mat2_tile_scalar( + out FPWeightTile tile, + const int n4_start, + const int k_start, + const int b, + const int N4, + const int N, + const int K) { + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (k_start + k < K && n4_start + n4 < N4) { + const int base = (b * K + k_start + k) * N + mul_4(n4_start + n4); + T s0 = t_mat2[base]; + T s1 = (mul_4(n4_start + n4) + 1 < N) ? t_mat2[base + 1] : T(0); + T s2 = (mul_4(n4_start + n4) + 2 < N) ? t_mat2[base + 2] : T(0); + T s3 = (mul_4(n4_start + n4) + 3 < N) ? 
t_mat2[base + 3] : T(0); + tile.data[k][n4] = VEC4_T(s0, s1, s2, s3); + } else { + tile.data[k][n4] = VEC4_T(0.0); + } + } + } +} +#endif // SCALAR_BUFFER + +#endif // MATMUL_FP_MAT2_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_out_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_out_tile_store.glslh new file mode 100644 index 00000000000..e6ee4b5a8ac --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_fp_out_tile_store.glslh @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_output + * - t_bias (when HAS_BIAS is defined) + * + * Macro Settings: + * - OUTPUT_BUFFER + * - HAS_BIAS + * - BIAS_BUFFER + */ + +#ifndef MATMUL_FP_OUT_TILE_STORE_GLSLH +#define MATMUL_FP_OUT_TILE_STORE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" +#include "linear_fp_per_out_channel_params.glslh" + +#ifndef SCALAR_BUFFER +void store_matmul_output_x4( + const VEC4_T texel, + const int n4, + const int m, + const int b, + const int N4, + const int M) { +#ifdef OUTPUT_BUFFER + t_output[(b * M + m) * N4 + n4] = texel; +#else + imageStore(t_output, ivec3(n4, m, b), texel); +#endif +} + +void store_matmul_out_tile_with_checks( + const FPOutTile out_tile, + const int n4_start, + const int m_start, + const int b, + const int N4, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + store_matmul_output_x4( + out_tile.data[m][n4], n4_start + n4, m_start + m, b, N4, M); + } + } + } +} +#endif // !SCALAR_BUFFER + +#ifdef SCALAR_BUFFER +void store_matmul_out_tile_scalar( + const FPOutTile out_tile, + 
const int n4_start, + const int m_start, + const int b, + const int N4, + const int N, + const int M) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + const int base = (b * M + m_start + m) * N + mul_4(n4_start + n4); + const VEC4_T val = out_tile.data[m][n4]; + t_output[base] = val.x; + if (mul_4(n4_start + n4) + 1 < N) t_output[base + 1] = val.y; + if (mul_4(n4_start + n4) + 2 < N) t_output[base + 2] = val.z; + if (mul_4(n4_start + n4) + 3 < N) t_output[base + 3] = val.w; + } + } + } +} +#endif // SCALAR_BUFFER + +#ifdef HAS_BIAS + +#ifndef SCALAR_BUFFER +void load_bias_params( + out FPPerOutChannelParams params, + const int n4_start, + const int N4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (n4_start + n4 < N4) { +#ifdef BIAS_BUFFER + params.data[n4] = t_bias[n4_start + n4]; +#else + params.data[n4] = texelFetch(t_bias, ivec3(n4_start + n4, 0, 0), 0); +#endif + } else { + params.data[n4] = VEC4_T(0.0); + } + } +} +#endif // !SCALAR_BUFFER + +void apply_addmm_bias( + inout FPOutTile out_tile, + const FPPerOutChannelParams bias) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = + VEC4_T(alpha) * out_tile.data[m][n4] + VEC4_T(beta) * bias.data[n4]; + } + } +} + +#endif // HAS_BIAS + +#endif // MATMUL_FP_OUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.glsl new file mode 100644 index 00000000000..a3f68dda839 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.glsl @@ -0,0 +1,101 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +#define OUTPUT_BUFFER +#define INPUT_BUFFER +#define MAT2_BUFFER +#define SCALAR_BUFFER +$if HAS_BIAS: + #define HAS_BIAS + #define BIAS_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_mat2", DTYPE, STORAGE, is_scalar_array=True)} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=True)} + +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "mat2_sizes")} + +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + +$if HAS_BIAS: + layout(push_constant) uniform restrict Block { + float alpha; + float beta; + }; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "matmul_fp_mat1_tile_load.glslh" +#include "matmul_fp_mat2_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "matmul_fp_out_tile_store.glslh" +#include "matmul_fp_bias_apply.glslh" + +void main() { + const int tile_idx_n = int(gl_GlobalInvocationID.x); + const int tile_idx_m = int(gl_GlobalInvocationID.y); + const int b = int(gl_GlobalInvocationID.z); + + const int n4_start = tile_idx_n * TILE_N4; + const int m_start = tile_idx_m * TILE_M; + + // mat1 sizes in WHCN order: {K, M, B, 1} + const int K = mat1_sizes.x; + const int M = mat1_sizes.y; + const int K4 = div_up_4(K); + + // mat2 sizes in WHCN order: {N, K, B, 1} + const int N = mat2_sizes.x; + const int N4 = 
div_up_4(N); + + if (n4_start >= N4 || m_start >= M) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + FPWeightTile w_tile; + + for (int k4 = 0; k4 < K4; k4++) { + load_mat1_tile_scalar(in_tile, k4, m_start, b, K4, K, M); + load_mat2_tile_scalar(w_tile, n4_start, mul_4(k4), b, N4, N, K); + fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile); + } + +#ifdef HAS_BIAS + apply_bias(out_tile, n4_start, m_start, N4, N); +#endif + + store_matmul_out_tile_scalar(out_tile, n4_start, m_start, b, N4, N, M); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.yaml new file mode 100644 index 00000000000..1003532bfc9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_scalar.yaml @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +matmul_scalar: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + HAS_BIAS: false + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + TILE_M: 4 + generate_variant_forall: + STORAGE: + - VALUE: buffer + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: matmul + - NAME: matmul_tile_row_2 + TILE_M: 2 + - NAME: matmul_tile_row_1 + TILE_M: 1 + - NAME: matmul_bias + HAS_BIAS: true + - NAME: matmul_bias_tile_row_2 + HAS_BIAS: true + TILE_M: 2 + - NAME: matmul_bias_tile_row_1 + HAS_BIAS: true + TILE_M: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_vec.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_vec.glsl new file mode 100644 index 00000000000..9d4f6f04aba --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_vec.glsl @@ -0,0 +1,102 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +$if STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER + #define MAT2_BUFFER +$if HAS_BIAS: + #define HAS_BIAS +$if STORAGE == "buffer" and HAS_BIAS: + #define BIAS_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat2", DTYPE, STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "mat2_sizes")} + +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + +$if HAS_BIAS: + layout(push_constant) uniform restrict Block { + float alpha; + float beta; + }; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "matmul_fp_mat1_tile_load.glslh" +#include "matmul_fp_mat2_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "matmul_fp_out_tile_store.glslh" +#include "matmul_fp_bias_apply.glslh" + +void main() { + const int tile_idx_n = int(gl_GlobalInvocationID.x); + const int tile_idx_m = int(gl_GlobalInvocationID.y); + const int b = int(gl_GlobalInvocationID.z); + + const int n4_start = tile_idx_n * TILE_N4; + const int m_start = tile_idx_m * TILE_M; + + // mat1 sizes in WHCN order: {K, M, B, 
1} + const int K = mat1_sizes.x; + const int M = mat1_sizes.y; + const int K4 = div_up_4(K); + + // mat2 sizes in WHCN order: {N, K, B, 1} + const int N = mat2_sizes.x; + const int N4 = div_up_4(N); + + if (n4_start >= N4 || m_start >= M) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + FPWeightTile w_tile; + + for (int k4 = 0; k4 < K4; k4++) { + load_mat1_tile_with_checks(in_tile, k4, m_start, b, K4, M); + load_mat2_tile_with_checks(w_tile, n4_start, mul_4(k4), b, N4, K); + fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile); + } + +#ifdef HAS_BIAS + apply_bias_vec(out_tile, n4_start, m_start, N4, N); +#endif + + store_matmul_out_tile_with_checks(out_tile, n4_start, m_start, b, N4, M); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_vec.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_vec.yaml new file mode 100644 index 00000000000..ba5e757bb27 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_vec.yaml @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +matmul_vec: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + HAS_BIAS: false + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + TILE_M: 4 + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: matmul_vec + - NAME: matmul_vec_tile_row_2 + TILE_M: 2 + - NAME: matmul_vec_tile_row_1 + TILE_M: 1 + - NAME: matmul_vec_bias + HAS_BIAS: true + - NAME: matmul_vec_bias_tile_row_2 + HAS_BIAS: true + TILE_M: 2 + - NAME: matmul_vec_bias_tile_row_1 + HAS_BIAS: true + TILE_M: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl new file mode 100644 index 00000000000..8976f4b8d69 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl @@ -0,0 +1,125 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if PACKED_STORAGE == "buffer": + #define OUTPUT_BUFFER + +#extension GL_EXT_control_flow_attributes : require + +${define_required_extensions("buffer", DTYPE)} +$if PACKED_STORAGE != "buffer": + ${define_required_extensions(PACKED_STORAGE, DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +$if PACKED_STORAGE == "buffer": + ${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, PACKED_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_src", DTYPE, "buffer", is_scalar_array=True)} + +layout(push_constant) uniform restrict Block { + int N; + int K; + int B; + int is_transposed; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Packs fp linear weight into 4OC x 4IC blocked layout. +// +// Source data is contiguous row-major with no per-row padding, so scalar reads +// are used to correctly handle dimensions that are not multiples of 4. +// +// When is_transposed != 0, source is [B, N, K] row-major (transposed weight): +// Scalar at (b, n, k) = t_weight_src[b * N * K + n * K + k] +// Each 4x4 block is transposed so that: +// packed[dk] = {w[k4*4+dk][n4*4+0..3]} +// +// When is_transposed == 0, source is [B, K, N] row-major (non-transposed): +// Scalar at (b, k, n) = t_weight_src[b * K * N + k * N + n] +// Already in the desired column grouping, no transpose needed. +// +// Output: batch-stacked blocked layout indexed by (b, k4, n4, dk). + +T load_scalar(const int idx) { + return T(t_weight_src[idx]); +} + +VEC4_T load_scalar_row(const int row_base, const int col, const int max_col) { + return VEC4_T( + load_scalar(row_base + col), + (col + 1 < max_col) ? load_scalar(row_base + col + 1) : T(0), + (col + 2 < max_col) ? 
load_scalar(row_base + col + 2) : T(0), + (col + 3 < max_col) ? load_scalar(row_base + col + 3) : T(0)); +} + +void main() { + const int n4 = int(gl_GlobalInvocationID.x); + const int k4 = int(gl_GlobalInvocationID.y); + const int b = int(gl_GlobalInvocationID.z); + + const int K4 = div_up_4(K); + const int N4 = div_up_4(N); + + if (n4 >= N4 || k4 >= K4 || b >= B) { + return; + } + + if (is_transposed != 0) { + // Source is [N, K] or [B, N, K]. + // Read 4 N-rows at the k4 column block, transpose into 4OC x 4IC block. + const int batch_offset = b * N * K; + VEC4_T src_rows[4]; + [[unroll]] for (int dn = 0; dn < 4; dn++) { + int n = n4 * 4 + dn; + if (n < N) { + src_rows[dn] = load_scalar_row(batch_offset + n * K, k4 * 4, K); + } else { + src_rows[dn] = VEC4_T(0); + } + } + [[unroll]] for (int dk = 0; dk < 4; dk++) { + VEC4_T out_val = VEC4_T( + src_rows[0][dk], src_rows[1][dk], + src_rows[2][dk], src_rows[3][dk]); +#ifdef OUTPUT_BUFFER + t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk] = out_val; +#else + imageStore(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), out_val); +#endif + } + } else { + // Source is [K, N] or [B, K, N]. + // Read 4 K-rows at the n4 column block. No transpose needed. 
+ const int batch_offset = b * K * N; + [[unroll]] for (int dk = 0; dk < 4; dk++) { + int k = k4 * 4 + dk; + VEC4_T out_val; + if (k < K) { + out_val = load_scalar_row(batch_offset + k * N, n4 * 4, N); + } else { + out_val = VEC4_T(0); + } +#ifdef OUTPUT_BUFFER + t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk] = out_val; +#else + imageStore(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), out_val); +#endif + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml similarity index 70% rename from backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml rename to backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml index b093d0c80b2..34793634435 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml @@ -4,16 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-addmm_naive_buffer: +pack_fp_linear_weight: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer - HAS_BIAS: false + PACKED_STORAGE: texture2d generate_variant_forall: + PACKED_STORAGE: + - VALUE: texture2d + - VALUE: buffer DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: matmul_naive_buffer - - NAME: addmm_naive_buffer - HAS_BIAS: true + - NAME: pack_fp_linear_weight diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 67c3f377f0c..ca7f55e85f2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -19,410 +20,242 @@ namespace vkcompute { -// Custom global workgroup size function for addmm_naive_texture -utils::uvec3 addmm_naive_texture_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - return graph->logical_limits_of(out); -} - -// Custom global workgroup size function for addmm_naive_buffer -utils::uvec3 addmm_naive_buffer_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - return { - graph->size_at(-1, out), - graph->size_at(-2, out), - graph->size_at(-3, out) * graph->size_at(-4, out)}; -} - -// Custom global workgroup size function for addmm_optimized -utils::uvec3 addmm_optimized_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1 = args.at(1).refs.at(0); - - std::vector mat1_sizes = graph->sizes_of(mat1); - int 
mat1_dims = mat1_sizes.size(); - - utils::uvec3 global_size = graph->logical_limits_of(out); - - if (mat1_sizes.at(mat1_dims - 2) < 8) { - global_size = utils::divup_vec(global_size, {4, 2, 1}); +ValueRef prepack_fp_linear_weight( + ComputeGraph& graph, + const ValueRef weight_data, + bool is_transposed, + int64_t B) { + std::vector weight_sizes = graph.sizes_of(weight_data); + + int64_t N, K; + if (is_transposed) { + // Source is [N, K] or [B, N, K] + N = weight_sizes.at(weight_sizes.size() - 2); + K = weight_sizes.at(weight_sizes.size() - 1); } else { - global_size = utils::divup_vec(global_size, {4, 4, 1}); + // Source is [K, N] or [B, K, N] + K = weight_sizes.at(weight_sizes.size() - 2); + N = weight_sizes.at(weight_sizes.size() - 1); } - return global_size; -} -// Custom local workgroup size function for addmm_optimized -utils::uvec3 addmm_optimized_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)args; - (void)resize_args; - return adaptive_work_group_size(global_workgroup_size); -} - -void check_addmm_args( - ComputeGraph& graph, - const ValueRef self, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef beta, - const ValueRef alpha, - const ValueRef out) { - (void)alpha; - (void)beta; - - std::vector self_sizes = graph.sizes_of(self); - std::vector mat1_sizes = graph.sizes_of(mat1); - std::vector mat2_sizes = graph.sizes_of(mat2_data); - - VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); - VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - - VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); + int64_t K4 = utils::div_up(K, int64_t(4)); + int64_t N4 = utils::div_up(N, int64_t(4)); + + // Packed tensor: B*K4 rows, N4*4 vec4 elements per row (batch-stacked). 
+ // Since the tensor size is in scalars and kWidthPacked packs 4 scalars per + // texel, we need width = N4*4*4 scalars to get N4*4 texels. + int64_t output_height = B * K4; + int64_t output_width = N4 * 4 * 4; + + utils::StorageType weight_storage = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + // output_width is in scalars; texture width in texels = output_width / 4 + if (output_width / 4 > max_extent || + static_cast(output_height) > max_extent) { + weight_storage = utils::kBuffer; + } - VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes)); + ValueRef packed_weight = graph.add_tensor( + {output_height, output_width}, + graph.dtype_of(weight_data), + weight_storage, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(N4), + utils::safe_downcast(K4), + utils::safe_downcast(B)}; + + struct PackParams { + int32_t N; + int32_t K; + int32_t B; + int32_t is_transposed; + }; + PackParams pack_params{ + utils::safe_downcast(N), + utils::safe_downcast(K), + utils::safe_downcast(B), + is_transposed ? 
1 : 0}; + + std::string kernel_name = "pack_fp_linear_weight"; + add_storage_type_suffix(kernel_name, weight_storage); + add_dtype_suffix(kernel_name, graph.dtype_of(weight_data)); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + weight_data, + packed_weight, + {}, + {}, + {PushConstantDataInfo(&pack_params, sizeof(PackParams))})); - if (utils::val_at(-1, self_sizes) != 1) { - VK_CHECK_COND( - utils::val_at(-1, self_sizes) == utils::val_at(-1, mat2_sizes)); - } - if (utils::val_at(-2, self_sizes) != 1) { - VK_CHECK_COND( - utils::val_at(-2, self_sizes) == utils::val_at(-2, mat1_sizes)); - } + return packed_weight; } -void resize_addmm_node( +void resize_linear_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { + const std::vector& resize_args) { const ValueRef out = args.at(0).refs.at(0); const ValueRef mat1 = args.at(1).refs.at(0); - const ValueRef mat2 = args.at(1).refs.at(1); - - const bool mat2_is_transposed = graph->get_bool(extra_args.at(0)); const std::vector mat1_sizes = graph->sizes_of(mat1); - const std::vector mat2_sizes = graph->sizes_of(mat2); - const int out_cols = utils::val_at(-2, mat1_sizes); - const int out_rows = mat2_is_transposed ? 
utils::val_at(-2, mat2_sizes) - : utils::val_at(-1, mat2_sizes); + int64_t M = mat1_sizes.at(mat1_sizes.size() - 2); + int64_t N = graph->get_int(resize_args.at(0)); - std::vector new_out_sizes(3); - if (mat1_sizes.size() == 2) { - new_out_sizes.resize(2); - new_out_sizes.at(0) = out_cols; - new_out_sizes.at(1) = out_rows; + if (mat1_sizes.size() >= 3) { + int64_t B = mat1_sizes.at(0); + graph->virtual_resize(out, {B, M, N}); } else { - new_out_sizes.at(0) = mat1_sizes.at(0); - new_out_sizes.at(1) = out_cols; - new_out_sizes.at(2) = out_rows; + graph->virtual_resize(out, {M, N}); } - - graph->virtual_resize(out, new_out_sizes); } -struct Params final { +struct LinearIntParams final { + int32_t weight_B; +}; + +struct LinearBiasParams final { float alpha; float beta; }; -void add_addmm_naive_texture_node( - ComputeGraph& graph, - const ValueRef self_data, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef beta, - const ValueRef alpha, - const ValueRef out, - const Params& params, - const ValueRef mat2_is_transposed) { - utils::StorageType stype = graph.storage_type_of(out); - ValueRef self = prepack_standard( - graph, self_data, stype, utils::kWidthPacked, /*passthrough = */ true); - ValueRef mat2 = prepack_standard( - graph, mat2_data, stype, utils::kHeightPacked, /*passthrough = */ true); - - std::string kernel_name = - graph.get_bool(mat2_is_transposed) ? 
"linear_naive" : "addmm_naive"; +vkapi::ShaderInfo pick_linear_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef input = args.at(1).refs.at(0); + const ValueRef packed_weight = args.at(1).refs.at(1); + bool has_bias = graph->get_bool(resize_args.at(1)); + uint32_t tile_m = pick_matmul_tile_m(graph, out); + + bool is_buffer = graph->storage_type_of(out) == utils::kBuffer; + // Use vec4 shader when all tensor widths are aligned to 4, even for buffers + uint32_t K = graph->size_at(-1, input); + uint32_t N = graph->size_at(-1, out); + bool use_scalar = is_buffer && (K % 4 != 0 || N % 4 != 0); + std::string base = use_scalar ? "linear" : "linear_vec"; + + std::string kernel_name; + if (has_bias) { + kernel_name = tile_m <= 1 ? base + "_bias_tile_row_1" + : tile_m <= 2 ? base + "_bias_tile_row_2" + : base + "_bias"; + } else { + kernel_name = tile_m <= 1 ? base + "_tile_row_1" + : tile_m <= 2 ? 
base + "_tile_row_2" + : base; + } kernel_name.reserve(kShaderNameReserve); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - addmm_naive_texture_global_wg_size, - pick_hw_square_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, - // Shader params buffers - {}, - // Push Constants - {graph.sizes_pc_of(out), - graph.sizes_pc_of(mat1), - graph.sizes_pc_of(mat2), - graph.logical_limits_pc_of(out), - graph.sizes_pc_of(self), - PushConstantDataInfo(¶ms, sizeof(params))}, - // Specialization Constants - {graph.hashed_layout_of(out), - graph.hashed_layout_of(mat1), - graph.hashed_layout_of(mat2), - graph.hashed_layout_of(self)}, - // Resize Args - {mat2_is_transposed}, - // Resizing Logic - resize_addmm_node)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(kernel_name); } -void add_addmm_naive_buffer_node( - ComputeGraph& graph, - const ValueRef self_data, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef beta, - const ValueRef alpha, - const ValueRef out, - const Params& params, - const ValueRef mat2_is_transposed) { - (void)beta; - (void)alpha; - ValueRef mat2 = prepack_standard( - graph, - mat2_data, - graph.storage_type_of(out), - utils::kHeightPacked, - /*passthrough = */ true); - ValueRef self = prepack_standard( - graph, - self_data, - graph.storage_type_of(out), - utils::kWidthPacked, - /*passthrough = */ true); - - std::string kernel_name = "addmm_naive_buffer"; - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - utils::uvec3 global_size = { - graph.size_at(-1, out), - 
graph.size_at(-2, out), - graph.size_at(-3, out) * graph.size_at(-4, out)}; - - int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef && - graph.get_bool(mat2_is_transposed)) - ? 1 - : 0; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - addmm_naive_buffer_global_wg_size, - pick_hw_square_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, - // Shader params buffers - { - graph.sizes_ubo(out), - graph.strides_ubo(out), - graph.sizes_ubo(mat1), - graph.strides_ubo(mat1), - graph.sizes_ubo(mat2), - graph.strides_ubo(mat2), - graph.numel_ubo(out), - graph.create_params_buffer(params), - }, - // Push Constants - {}, - // Specialization Constants - {mat2_is_transposed_val}, - // Resize Args - {mat2_is_transposed}, - // Resizing Logic - resize_addmm_node)); +utils::uvec3 pick_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + uint32_t N = graph->size_at(-1, out); + uint32_t M = graph->size_at(-2, out); + uint32_t B = graph->dim_of(out) >= 3 ? 
graph->size_at(-3, out) : 1; + uint32_t tile_m = pick_matmul_tile_m(graph, out); + return {utils::div_up_4(N), utils::div_up(M, tile_m), B}; } -void add_addmm_optimized_node( +void add_linear_tiled_node( ComputeGraph& graph, - const ValueRef self_data, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef beta, - const ValueRef alpha, + const ValueRef input, + const ValueRef packed_weight, + const ValueRef packed_bias, + bool has_bias, const ValueRef out, - const Params& params, - const ValueRef mat2_is_transposed) { - utils::StorageType stype = graph.storage_type_of(out); - ValueRef self = prepack_standard( - graph, self_data, stype, utils::kChannelsPacked, /*passthrough=*/true); - ValueRef mat2 = prepack_standard( - graph, mat2_data, stype, utils::kHeightPacked, /*passthrough=*/true); - - // Ensure mat1 is width packed - ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); - auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); - viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); - - const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed); - - // Ensure mat2 is height packed - ValueRef mat2_packed = mat2; - const utils::GPUMemoryLayout mat2_layout = - mat2_is_transposed_val ? 
utils::kWidthPacked : utils::kHeightPacked; - if (graph.estimate_memory_layout_of(mat2) != mat2_layout) { - mat2_packed = graph.add_tensor_like(mat2, mat2_layout); - viewFn(graph, {mat2, graph.add_none(), mat2_packed}); + int32_t weight_B, + float alpha, + float beta) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + std::vector out_sizes = graph.sizes_of(out); + int32_t orig_N = utils::safe_downcast(out_sizes.back()); + + LinearIntParams int_params{weight_B}; + LinearBiasParams bias_params{alpha, beta}; + ValueRef has_bias_ref = graph.add_scalar(has_bias); + ValueRef orig_N_ref = graph.add_scalar(static_cast(orig_N)); + + std::vector read_inputs = {input, packed_weight}; + if (has_bias) { + read_inputs.push_back(packed_bias); } - std::string kernel_name = graph.get_bool(mat2_is_transposed) - ? "linear_optimized" - : "addmm_optimized"; - - std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); - int mat1_dims = mat1_sizes.size(); - if (mat1_dims == 3) { - kernel_name = "batch_" + kernel_name; + std::vector push_constants = { + PushConstantDataInfo(&int_params, sizeof(LinearIntParams)), + }; + if (has_bias) { + push_constants.push_back( + PushConstantDataInfo(&bias_params, sizeof(LinearBiasParams))); } - if (mat1_sizes.at(mat1_dims - 2) < 8) { - kernel_name += "_tile_row_2"; - } else { - kernel_name += "_tile_row_4"; + + vkapi::ParamsBindList shader_params = { + graph.sizes_ubo(input), graph.sizes_ubo(out)}; + if (has_bias) { + shader_params.append(graph.sizes_ubo(packed_bias)); } - add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - addmm_optimized_global_wg_size, - addmm_optimized_local_wg_size, + pick_linear_shader, + pick_linear_global_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs - {{out, vkapi::kWrite}, - {{mat1_W_packed, mat2_packed, self}, vkapi::kRead}}, + {{out, 
vkapi::kWrite}, {read_inputs, vkapi::kRead}}, // Shader params buffers - { - graph.sizes_ubo(out), - graph.sizes_ubo(mat1_W_packed), - graph.sizes_ubo(mat2_packed), - graph.sizes_ubo(self), - graph.create_params_buffer(params), - }, + shader_params, // Push Constants - {}, + push_constants, // Specialization Constants - {graph.hashed_layout_of(out), - graph.hashed_layout_of(mat1_W_packed), - graph.hashed_layout_of(mat2_packed), - graph.hashed_layout_of(self)}, + {}, // Resize Args - {mat2_is_transposed}, + {orig_N_ref, has_bias_ref}, // Resizing Logic - resize_addmm_node)); + resize_linear_node)); } -void add_addmm_node( +void linear_packed_weight( ComputeGraph& graph, - const ValueRef self, - const ValueRef mat1, - const ValueRef mat2, - const ValueRef beta, - const ValueRef alpha, - const ValueRef out, - const ValueRef mat2_is_transposed) { - float alpha_val = 1.0f; - float beta_val = 1.0f; - - if (alpha != kDummyValueRef) { - alpha_val = graph.extract_scalar(alpha); - } - if (beta != kDummyValueRef) { - beta_val = graph.extract_scalar(beta); - } - - Params params = {alpha_val, beta_val}; - if (graph.is_buffer_storage(out)) { - add_addmm_naive_buffer_node( - graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); - } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) { - add_addmm_optimized_node( - graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); - } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) { - add_addmm_naive_texture_node( - graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); - } else { - VK_THROW("Input should be channel packed or width packed."); - } -} - -void addmm(ComputeGraph& graph, const std::vector& args) { - check_addmm_args(graph, args[0], args[1], args[2], args[3], args[4], args[5]); - ValueRef mat2_is_transposed = graph.add_scalar(false); - return add_addmm_node( - graph, - args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - mat2_is_transposed); -} - -void 
linear(ComputeGraph& graph, const std::vector& args) { + const std::vector& args) { ValueRef input = args.at(0); ValueRef weight_data = args.at(1); ValueRef bias = args.at(2); ValueRef out = args.at(3); - ValueRef weight = prepack_standard( - graph, - weight_data, - graph.storage_type_of(out), - utils::kWidthPacked, - /*passthrough = */ true); - ValueRef mat2_is_transposed = graph.add_scalar(true); - if (graph.val_is_none(bias)) { - return add_matmul_node(graph, input, weight, out, mat2_is_transposed); - } else { - return add_addmm_node( - graph, - bias, - input, - weight, - kDummyValueRef, - kDummyValueRef, - out, - mat2_is_transposed); + ValueRef packed_weight = prepack_fp_linear_weight( + graph, weight_data, /*is_transposed=*/true, /*B=*/1); + + ValueRef packed_bias = kDummyValueRef; + bool has_bias = graph.val_is_not_none(bias); + if (has_bias) { + packed_bias = prepack_standard( + graph, bias, graph.storage_type_of(out), utils::kWidthPacked); } + + add_linear_tiled_node( + graph, input, packed_weight, packed_bias, has_bias, out); } REGISTER_OPERATORS { - VK_REGISTER_OP(aten.addmm.default, addmm); - VK_REGISTER_OP(aten.linear.default, linear); + VK_REGISTER_OP(aten.linear.default, linear_packed_weight); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.h b/backends/vulkan/runtime/graph/ops/impl/Linear.h new file mode 100644 index 00000000000..d7efb8c8b08 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace vkcompute { + +ValueRef prepack_fp_linear_weight( + ComputeGraph& graph, + const ValueRef weight_data, + bool is_transposed, + int64_t B); + +void add_linear_tiled_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef packed_weight, + const ValueRef packed_bias, + bool has_bias, + const ValueRef out, + int32_t weight_B = 1, + float alpha = 1.0f, + float beta = 1.0f); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp deleted file mode 100644 index 6c687ec67a8..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include -#include - -#include -#include - -#include - -namespace vkcompute { - -void check_matmul_args( - const ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef out) { - std::vector mat1_sizes = graph.sizes_of(mat1); - std::vector mat2_sizes = graph.sizes_of(mat2_data); - - VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); - VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - - VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); - - VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes)); -} - -void resize_matmul_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1 = args.at(1).refs.at(0); - const ValueRef mat2 = args.at(1).refs.at(1); - - bool mat2_is_transposed = graph->get_bool(resize_args.at(0)); - - const std::vector mat1_sizes = graph->sizes_of(mat1); - const std::vector mat2_sizes = 
graph->sizes_of(mat2); - - const int out_cols = utils::val_at(-2, mat1_sizes); - const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2_sizes) - : utils::val_at(-1, mat2_sizes); - - const int64_t out_dim = graph->dim_of(out); - std::vector new_out_sizes(mat1_sizes); - new_out_sizes.at(out_dim - 1) = out_rows; - new_out_sizes.at(out_dim - 2) = out_cols; - - graph->virtual_resize(out, new_out_sizes); -} - -/** - * Custom global workgroup size function for naive buffer matmul operations. - */ -utils::uvec3 matmul_naive_buffer_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - return { - graph->size_at(-1, out), - graph->size_at(-2, out), - graph->size_at(-3, out) * graph->size_at(-4, out)}; -} - -void add_matmul_naive_buffer_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef out, - const ValueRef mat2_is_transposed) { - ValueRef mat2 = prepack_standard( - graph, - mat2_data, - graph.storage_type_of(out), - utils::kHeightPacked, - /*passthrough = */ true); - - std::string kernel_name = "matmul_naive_buffer"; - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef && - graph.get_bool(mat2_is_transposed)) - ? 
1 - : 0; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - matmul_naive_buffer_global_wg_size, - pick_hw_square_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, - // Shader params buffers - { - graph.sizes_ubo(out), - graph.strides_ubo(out), - graph.sizes_ubo(mat1), - graph.strides_ubo(mat1), - graph.sizes_ubo(mat2), - graph.strides_ubo(mat2), - graph.numel_ubo(out), - }, - // Push Constants - {}, - // Specialization Constants - {mat2_is_transposed_val}, - // Resize Args - {mat2_is_transposed}, - // Resizing Logic - resize_matmul_node)); -} - -vkapi::ShaderInfo pick_matmul_naive_texture3d_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef out = args.at(0).refs.at(0); - const bool is_transposed = graph->get_bool(resize_args.at(0)); - - std::string kernel_name = - is_transposed ? "matmul_transposed_naive" : "matmul_naive"; - kernel_name.reserve(kShaderNameReserve); - add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); - add_dtype_suffix(kernel_name, graph->dtype_of(out)); - - return VK_KERNEL_FROM_STR(kernel_name); -} - -void add_matmul_naive_texture3d_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef out, - const ValueRef mat2_is_transposed) { - ValueRef mat2 = prepack_standard( - graph, - mat2_data, - graph.storage_type_of(out), - utils::kHeightPacked, - /*passthrough = */ true); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - pick_matmul_naive_texture3d_shader, - default_pick_global_wg_size, - pick_hw_square_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, - // Shader params buffers - {}, - // Push Constants - {graph.sizes_pc_of(out), - graph.sizes_pc_of(mat1), - graph.sizes_pc_of(mat2), - graph.logical_limits_pc_of(out)}, - // Specialization Constants - 
{graph.hashed_layout_of(out), - graph.hashed_layout_of(mat1), - graph.hashed_layout_of(mat2)}, - // Resize Args - {mat2_is_transposed}, - // Resizing Logic - resize_matmul_node)); -} - -vkapi::ShaderInfo pick_matmul_optimized_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1_W_packed = resize_args.at(1); - const bool mat2_is_transposed_val = graph->get_bool(resize_args.at(0)); - - std::string kernel_name = mat2_is_transposed_val - ? "matmul_transposed_optimized" - : "matmul_optimized"; - - std::vector mat1_sizes = graph->sizes_of(mat1_W_packed); - size_t mat1_dims = mat1_sizes.size(); - if (mat1_dims == 3) { - kernel_name = "batch_" + kernel_name; - } - if (mat1_sizes.at(mat1_dims - 2) < 8) { - kernel_name += "_tile_row_2"; - } else { - kernel_name += "_tile_row_4"; - } - - add_dtype_suffix(kernel_name, graph->dtype_of(out)); - - return VK_KERNEL_FROM_STR(kernel_name); -} - -utils::uvec3 matmul_optimized_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1_W_packed = resize_args.at(1); - - const std::vector mat1_sizes = graph->sizes_of(mat1_W_packed); - const size_t mat1_dims = mat1_sizes.size(); - - utils::uvec3 global_size = graph->logical_limits_of(out); - if (mat1_sizes.at(mat1_dims - 2) < 8) { - // Use `logical_extents` instead of `image_extents` because the workgroup - // axes need to correspond to tensor dimensions. 
- global_size = utils::divup_vec(global_size, {4, 2, 1}); - } else { - global_size = utils::divup_vec(global_size, {4, 4, 1}); - } - - return global_size; -} - -void add_matmul_optimized_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef out, - const ValueRef mat2_is_transposed) { - ValueRef mat2 = prepack_standard( - graph, - mat2_data, - graph.storage_type_of(out), - utils::kHeightPacked, - /*passthrough = */ true); - - // Ensure mat1 is width packed - TmpTensor mat1_tmp( - &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); - ValueRef mat1_W_packed = mat1; - auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); - if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) { - mat1_W_packed = mat1_tmp; - viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); - } - - const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed); - - // Ensure mat2 to height packed - ValueRef mat2_packed = mat2; - const utils::GPUMemoryLayout mat2_layout = - mat2_is_transposed_val ? 
utils::kWidthPacked : utils::kHeightPacked; - TmpTensor mat2_tmp( - &graph, graph.sizes_of(mat2), graph.dtype_of(mat2), mat2_layout); - if (graph.estimate_memory_layout_of(mat2) != mat2_layout) { - mat2_packed = mat2_tmp; - viewFn(graph, {mat2, graph.add_none(), mat2_packed}); - } - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - pick_matmul_optimized_shader, - matmul_optimized_global_wg_size, - pick_hw_square_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}}, - // Shader params buffers - { - graph.sizes_ubo(out), - graph.sizes_ubo(mat1_W_packed), - graph.sizes_ubo(mat2_packed), - }, - // Push Constants - {}, - // Specialization Constants - {graph.hashed_layout_of(out), - graph.hashed_layout_of(mat1_W_packed), - graph.hashed_layout_of(mat2_packed)}, - // Resize Args - {mat2_is_transposed, mat1_W_packed}, - // Resizing Logic - resize_matmul_node)); -} - -void add_matmul_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef out, - const ValueRef mat2_is_transposed) { - if (graph.is_buffer_storage(out)) { - add_matmul_naive_buffer_node( - graph, mat1, mat2_data, out, mat2_is_transposed); - } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) { - add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed); - } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) { - add_matmul_naive_texture3d_node( - graph, mat1, mat2_data, out, mat2_is_transposed); - } else { - VK_THROW("Input texture should be channel packed or width packed."); - } -} - -void matmul(ComputeGraph& graph, const std::vector& args) { - check_matmul_args(graph, args[0], args[1], args[2]); - const ValueRef mat2_is_transposed = graph.add_scalar(false); - return add_matmul_node(graph, args[0], args[1], args[2], mat2_is_transposed); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(aten.mm.default, matmul); - VK_REGISTER_OP(aten.bmm.default, matmul); -} - -} // namespace 
vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.h b/backends/vulkan/runtime/graph/ops/impl/MatMul.h index 38f7907f1b6..c950449a5bd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.h +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.h @@ -12,11 +12,6 @@ namespace vkcompute { -void add_matmul_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef out, - const ValueRef mat2_is_transposed); +uint32_t pick_matmul_tile_m(ComputeGraph* graph, const ValueRef out); } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Matmul.cpp b/backends/vulkan/runtime/graph/ops/impl/Matmul.cpp new file mode 100644 index 00000000000..53bb8d82e12 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Matmul.cpp @@ -0,0 +1,265 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include + +#include + +namespace vkcompute { + +void resize_matmul_tiled_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); + + const std::vector mat1_sizes = graph->sizes_of(mat1); + const std::vector mat2_sizes = graph->sizes_of(mat2); + + std::vector new_out_sizes(mat1_sizes); + new_out_sizes.at(new_out_sizes.size() - 1) = mat2_sizes.back(); + new_out_sizes.at(new_out_sizes.size() - 2) = + mat1_sizes.at(mat1_sizes.size() - 2); + + graph->virtual_resize(out, new_out_sizes); +} + +// Minimum number of thread groups to target for good GPU occupancy. When the +// default 4-row tiling produces fewer threads than this, a smaller tile is +// selected to increase parallelism. 
+static constexpr uint32_t kMinOccupancyThreads = 4096; + +// Returns the M tile size (1, 2, or 4) to use for the matmul shader. The +// largest tile that produces at least kMinOccupancyThreads thread groups is +// chosen; if even tile_m=1 doesn't meet the threshold, tile_m=1 is used. +uint32_t pick_matmul_tile_m(ComputeGraph* graph, const ValueRef out) { + uint32_t N = graph->size_at(-1, out); + uint32_t M = graph->size_at(-2, out); + uint32_t B = graph->dim_of(out) >= 3 ? graph->size_at(-3, out) : 1; + uint32_t n_groups = utils::div_up_4(N); + // Try tile_m = 4, 2, 1 in descending order; pick the first that gives + // enough threads. + for (uint32_t tile_m : {4u, 2u, 1u}) { + uint32_t total = n_groups * utils::div_up(M, tile_m) * B; + if (total >= kMinOccupancyThreads) { + return tile_m; + } + } + return 1u; +} + +vkapi::ShaderInfo pick_matmul_tiled_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + bool has_bias = graph->get_bool(resize_args.at(0)); + uint32_t tile_m = pick_matmul_tile_m(graph, out); + + bool is_buffer = graph->storage_type_of(out) == utils::kBuffer; + // Use vec4 shader when all tensor widths are aligned to 4, even for buffers + uint32_t K = graph->size_at(-1, mat1); + uint32_t N = graph->size_at(-1, out); + bool use_scalar = is_buffer && (K % 4 != 0 || N % 4 != 0); + std::string base = use_scalar ? "matmul" : "matmul_vec"; + + std::string kernel_name; + if (has_bias) { + kernel_name = tile_m <= 1 ? base + "_bias_tile_row_1" + : tile_m <= 2 ? base + "_bias_tile_row_2" + : base + "_bias"; + } else { + kernel_name = tile_m <= 1 ? base + "_tile_row_1" + : tile_m <= 2 ? 
base + "_tile_row_2" + : base; + } + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 pick_matmul_tiled_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + uint32_t N = graph->size_at(-1, out); + uint32_t M = graph->size_at(-2, out); + uint32_t B = graph->dim_of(out) >= 3 ? graph->size_at(-3, out) : 1; + uint32_t tile_m = pick_matmul_tile_m(graph, out); + return {utils::div_up_4(N), utils::div_up(M, tile_m), B}; +} + +void add_matmul_tiled_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(mat2) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + ValueRef has_bias_ref = graph.add_scalar(false); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_matmul_tiled_shader, + pick_matmul_tiled_global_wg_size, + pick_hw_square_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, + // Shader params buffers + {graph.sizes_ubo(mat1), graph.sizes_ubo(mat2)}, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {has_bias_ref}, + // Resizing Logic + resize_matmul_tiled_node)); +} + +struct MatmulBiasParams final { + float alpha; + float beta; +}; + +void add_addmm_tiled_node( + ComputeGraph& graph, + const ValueRef bias, + const ValueRef mat1, + const ValueRef mat2, + const ValueRef out, + float alpha_val, + float beta_val) { + VK_CHECK_COND(graph.packed_dim_of(bias) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); + 
VK_CHECK_COND(graph.packed_dim_of(mat2) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + + MatmulBiasParams params{alpha_val, beta_val}; + + ValueRef has_bias_ref = graph.add_scalar(true); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_matmul_tiled_shader, + pick_matmul_tiled_global_wg_size, + pick_hw_square_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{mat1, mat2, bias}, vkapi::kRead}}, + // Shader params buffers + {graph.sizes_ubo(mat1), graph.sizes_ubo(mat2), graph.sizes_ubo(bias)}, + // Push Constants + {PushConstantDataInfo(¶ms, sizeof(params))}, + // Specialization Constants + {}, + // Resize Args + {has_bias_ref}, + // Resizing Logic + resize_matmul_tiled_node)); +} + +void matmul_tiled(ComputeGraph& graph, const std::vector& args) { + // args: mat1, mat2, out + ValueRef mat1 = args[0]; + ValueRef mat2 = args[1]; + ValueRef out = args[2]; + + if (graph.val_is_tref(mat2)) { + auto mat2_sizes = graph.sizes_of(mat2); + int64_t B = mat2_sizes.size() >= 3 ? mat2_sizes.at(0) : 1; + ValueRef packed = + prepack_fp_linear_weight(graph, mat2, /*is_transposed=*/false, B); + add_linear_tiled_node( + graph, + mat1, + packed, + kDummyValueRef, + false, + out, + utils::safe_downcast(B)); + } else { + add_matmul_tiled_node(graph, mat1, mat2, out); + } +} + +void addmm_tiled(ComputeGraph& graph, const std::vector& args) { + // args: self, mat1, mat2, beta, alpha, out + ValueRef self = args[0]; + ValueRef mat1 = args[1]; + ValueRef mat2 = args[2]; + ValueRef beta_ref = args[3]; + ValueRef alpha_ref = args[4]; + ValueRef out = args[5]; + + float alpha_val = alpha_ref != kDummyValueRef + ? graph.extract_scalar(alpha_ref) + : 1.0f; + float beta_val = + beta_ref != kDummyValueRef ? graph.extract_scalar(beta_ref) : 1.0f; + + if (graph.val_is_tref(mat2)) { + auto mat2_sizes = graph.sizes_of(mat2); + int64_t B = mat2_sizes.size() >= 3 ? 
mat2_sizes.at(0) : 1; + ValueRef packed = + prepack_fp_linear_weight(graph, mat2, /*is_transposed=*/false, B); + + ValueRef packed_bias = kDummyValueRef; + bool has_bias = graph.val_is_not_none(self); + if (has_bias) { + packed_bias = prepack_standard( + graph, + self, + graph.storage_type_of(out), + utils::kWidthPacked, + /*passthrough=*/true); + } + add_linear_tiled_node( + graph, + mat1, + packed, + packed_bias, + has_bias, + out, + utils::safe_downcast(B), + alpha_val, + beta_val); + } else { + ValueRef bias = prepack_standard( + graph, + self, + graph.storage_type_of(out), + utils::kWidthPacked, + /*passthrough=*/true); + add_addmm_tiled_node(graph, bias, mat1, mat2, out, alpha_val, beta_val); + } +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.mm.default, matmul_tiled); + VK_REGISTER_OP(aten.bmm.default, matmul_tiled); + VK_REGISTER_OP(aten.addmm.default, addmm_tiled); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestMatmulLinear.cpp b/backends/vulkan/test/custom_ops/impl/TestMatmulLinear.cpp new file mode 100644 index 00000000000..2763ab7d7e7 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestMatmulLinear.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace vkcompute { + +void test_mm(ComputeGraph& graph, const std::vector& args) { + const ValueRef mat1 = args.at(0); + const ValueRef mat2 = args.at(1); + const ValueRef impl_selector_str = args.at(2); + const ValueRef out = args.at(3); + + std::string impl_selector = graph.extract_string(impl_selector_str); + std::string op_name = "aten.mm." 
+ impl_selector; + + VK_GET_OP_FN(op_name.c_str())(graph, {mat1, mat2, out}); +} + +void test_bmm(ComputeGraph& graph, const std::vector& args) { + const ValueRef mat1 = args.at(0); + const ValueRef mat2 = args.at(1); + const ValueRef impl_selector_str = args.at(2); + const ValueRef out = args.at(3); + + std::string impl_selector = graph.extract_string(impl_selector_str); + std::string op_name = "aten.bmm." + impl_selector; + + VK_GET_OP_FN(op_name.c_str())(graph, {mat1, mat2, out}); +} + +void test_addmm(ComputeGraph& graph, const std::vector& args) { + const ValueRef self = args.at(0); + const ValueRef mat1 = args.at(1); + const ValueRef mat2 = args.at(2); + const ValueRef beta = args.at(3); + const ValueRef alpha = args.at(4); + const ValueRef impl_selector_str = args.at(5); + const ValueRef out = args.at(6); + + std::string impl_selector = graph.extract_string(impl_selector_str); + std::string op_name = "aten.addmm." + impl_selector; + + VK_GET_OP_FN(op_name.c_str())(graph, {self, mat1, mat2, beta, alpha, out}); +} + +void test_linear(ComputeGraph& graph, const std::vector& args) { + const ValueRef input = args.at(0); + const ValueRef weight = args.at(1); + const ValueRef bias = args.at(2); + const ValueRef impl_selector_str = args.at(3); + const ValueRef out = args.at(4); + + std::string impl_selector = graph.extract_string(impl_selector_str); + std::string op_name = "aten.linear." 
+ impl_selector; + + VK_GET_OP_FN(op_name.c_str())(graph, {input, weight, bias, out}); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.test_mm.default, test_mm); + VK_REGISTER_OP(test_etvk.test_bmm.default, test_bmm); + VK_REGISTER_OP(test_etvk.test_addmm.default, test_addmm); + VK_REGISTER_OP(test_etvk.test_linear.default, test_linear); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index ba4873af603..fef8994718f 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -99,3 +99,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_q8ta_conv2d_dw") define_custom_op_test_binary("test_q8ta_linear") define_custom_op_test_binary("test_q8ta_conv2d_transposed") + define_custom_op_test_binary("test_mm") diff --git a/backends/vulkan/test/custom_ops/test_mm.cpp b/backends/vulkan/test/custom_ops/test_mm.cpp new file mode 100644 index 00000000000..aac54ad1514 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_mm.cpp @@ -0,0 +1,465 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include + +#include +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 256; + +struct MmConfig { + // mat1: [M, K] or [B, M, K] + // mat2: [K, N] or [B, K, N] + int64_t B; // batch size, 0 for non-batched + int64_t M; + int64_t K; + int64_t N; + bool has_bias; // true for addmm/linear + bool mat2_is_transposed; // true for linear (weight is [N, K]) + bool mat2_is_constant; // true to test prepacked linear path +}; + +struct MmShape { + int64_t B, M, K, N; +}; + +static TestCase create_mm_test_case( + const MmConfig& config, + vkapi::ScalarType dtype, + utils::StorageType storage_type, + utils::GPUMemoryLayout memory_layout) { + TestCase test_case; + + bool is_batched = config.B > 0; + bool is_perf = config.M > kRefDimSizeLimit || config.K > kRefDimSizeLimit || + config.N > kRefDimSizeLimit; + + std::string prefix = is_perf ? "PERF" : "ACCU"; + std::string storage_str = storage_type_abbrev(storage_type); + std::string layout_str = layout_abbrev(memory_layout); + std::string dtype_str = (dtype == vkapi::kHalf) ? "f16" : "f32"; + + // Determine op type string + std::string op_type; + if (config.mat2_is_transposed) { + op_type = config.has_bias ? "linear-bias" : "linear"; + } else if (config.has_bias) { + op_type = config.mat2_is_constant ? "addmm-const-mat2" : "addmm"; + } else if (is_batched) { + op_type = config.mat2_is_constant ? "bmm-const-mat2" : "bmm"; + } else { + op_type = config.mat2_is_constant ? 
"mm-const-mat2" : "mm"; + } + + // Build shape string + std::string shape; + if (is_batched) { + shape = "[" + std::to_string(config.B) + "," + std::to_string(config.M) + + "," + std::to_string(config.K) + "]x[" + std::to_string(config.B) + + "," + std::to_string(config.K) + "," + std::to_string(config.N) + "]"; + } else if (config.mat2_is_transposed) { + shape = "[" + std::to_string(config.M) + "," + std::to_string(config.K) + + "]x[" + std::to_string(config.N) + "," + std::to_string(config.K) + "]"; + } else { + shape = "[" + std::to_string(config.M) + "," + std::to_string(config.K) + + "]x[" + std::to_string(config.K) + "," + std::to_string(config.N) + "]"; + } + + std::string name = prefix + " " + op_type + " " + shape + " " + + storage_str + "(" + layout_str + ") " + dtype_str; + + test_case.set_name(name); + + // Determine op name - use test wrapper operators + std::string op_name; + if (is_batched) { + op_name = "test_etvk.test_bmm.default"; + } else if (config.mat2_is_transposed) { + op_name = "test_etvk.test_linear.default"; + } else if (config.has_bias) { + op_name = "test_etvk.test_addmm.default"; + } else { + op_name = "test_etvk.test_mm.default"; + } + test_case.set_operator_name(op_name); + + // mat1 + std::vector mat1_sizes; + if (is_batched) { + mat1_sizes = {config.B, config.M, config.K}; + } else { + mat1_sizes = {config.M, config.K}; + } + ValueSpec mat1( + mat1_sizes, dtype, storage_type, memory_layout, DataGenType::RANDOM); + + // mat2 - for linear, weight is [N, K] (transposed) + std::vector mat2_sizes; + if (config.mat2_is_transposed) { + mat2_sizes = {config.N, config.K}; + } else if (is_batched) { + mat2_sizes = {config.B, config.K, config.N}; + } else { + mat2_sizes = {config.K, config.N}; + } + + if (config.mat2_is_transposed) { + // For linear, weight is a constant tensor + ValueSpec mat2( + mat2_sizes, dtype, storage_type, memory_layout, DataGenType::RANDOM); + mat2.set_constant(true); + + // bias (or none) + if (config.has_bias) { + 
ValueSpec bias( + {config.N}, dtype, storage_type, memory_layout, DataGenType::RANDOM); + bias.set_constant(true); + + // test_etvk.test_linear.default: input, weight, bias, impl_selector, out + test_case.add_input_spec(mat1); + test_case.add_input_spec(mat2); + test_case.add_input_spec(bias); + } else { + test_case.add_input_spec(mat1); + test_case.add_input_spec(mat2); + // Use an int spec marked as none to avoid being treated as a tensor + ValueSpec none_bias(static_cast(0)); + none_bias.set_none(true); + test_case.add_input_spec(none_bias); + } + } else if (config.has_bias) { + // test_etvk.test_addmm.default: self, mat1, mat2, beta, alpha, + // impl_selector, out + ValueSpec bias( + {config.N}, dtype, storage_type, memory_layout, DataGenType::RANDOM); + ValueSpec mat2( + mat2_sizes, dtype, storage_type, memory_layout, DataGenType::RANDOM); + if (config.mat2_is_constant) { + mat2.set_constant(true); + } + + test_case.add_input_spec(bias); + test_case.add_input_spec(mat1); + test_case.add_input_spec(mat2); + // beta + test_case.add_input_spec(ValueSpec(1.0f)); + // alpha + test_case.add_input_spec(ValueSpec(1.0f)); + } else { + // test_etvk.test_mm.default or test_etvk.test_bmm.default: + // mat1, mat2, impl_selector, out + ValueSpec mat2( + mat2_sizes, dtype, storage_type, memory_layout, DataGenType::RANDOM); + if (config.mat2_is_constant) { + mat2.set_constant(true); + } + test_case.add_input_spec(mat1); + test_case.add_input_spec(mat2); + } + + // impl_selector (added before output for all variants) + ValueSpec impl_selector_spec = ValueSpec::make_string("default"); + test_case.add_input_spec(impl_selector_spec); + + // output + std::vector out_sizes; + if (is_batched) { + out_sizes = {config.B, config.M, config.N}; + } else { + out_sizes = {config.M, config.N}; + } + ValueSpec output( + out_sizes, dtype, storage_type, memory_layout, DataGenType::ZEROS); + test_case.add_output_spec(output); + + // Set tolerances - half precision needs wider tolerance + if 
(dtype == vkapi::kHalf) { + test_case.set_abs_tolerance(1e-1f); + test_case.set_rel_tolerance(1e-2f); + } else { + test_case.set_abs_tolerance(1e-3f); + test_case.set_rel_tolerance(1e-3f); + } + + // Filter out layout conversion shaders from timing + test_case.set_shader_filter({"nchw_to", "to_nchw", "view_copy"}); + + return test_case; +} + +// Reference implementation for mm/bmm +// Input layout per test operator: +// test_mm/test_bmm: mat1[0], mat2[1], impl_selector[2] +// test_addmm: self[0], mat1[1], mat2[2], beta[3], alpha[4], impl_selector[5] +// test_linear: input[0], weight[1], bias[2], impl_selector[3] +static void mm_reference_impl(TestCase& test_case) { + const std::string& op_name = test_case.operator_name(); + ValueSpec& output = test_case.outputs()[0]; + auto out_sizes = output.get_tensor_sizes(); + + if (test_case.inputs()[0].dtype != vkapi::kFloat) { + throw std::invalid_argument("Reference only supports float"); + } + + if (op_name == "test_etvk.test_mm.default") { + // mat1[0], mat2[1], impl_selector[2] + const auto& mat1 = test_case.inputs()[0]; + const auto& mat2 = test_case.inputs()[1]; + auto mat1_sizes = mat1.get_tensor_sizes(); + auto mat2_sizes = mat2.get_tensor_sizes(); + + int64_t M = mat1_sizes[0]; + int64_t K = mat1_sizes[1]; + int64_t N = mat2_sizes[1]; + + auto& mat1_data = mat1.get_float_data(); + auto& mat2_data = mat2.get_float_data(); + auto& ref_data = output.get_ref_float_data(); + ref_data.resize(M * N, 0.0f); + + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += mat1_data[m * K + k] * mat2_data[k * N + n]; + } + ref_data[m * N + n] = sum; + } + } + } else if (op_name == "test_etvk.test_bmm.default") { + // mat1[0], mat2[1], impl_selector[2] + const auto& mat1 = test_case.inputs()[0]; + const auto& mat2 = test_case.inputs()[1]; + auto mat1_sizes = mat1.get_tensor_sizes(); + auto mat2_sizes = mat2.get_tensor_sizes(); + + int64_t B = 
mat1_sizes[0]; + int64_t M = mat1_sizes[1]; + int64_t K = mat1_sizes[2]; + int64_t N = mat2_sizes[2]; + + auto& mat1_data = mat1.get_float_data(); + auto& mat2_data = mat2.get_float_data(); + auto& ref_data = output.get_ref_float_data(); + ref_data.resize(B * M * N, 0.0f); + + for (int64_t b = 0; b < B; ++b) { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += mat1_data[b * M * K + m * K + k] * + mat2_data[b * K * N + k * N + n]; + } + ref_data[b * M * N + m * N + n] = sum; + } + } + } + } else if (op_name == "test_etvk.test_addmm.default") { + // self[0], mat1[1], mat2[2], beta[3], alpha[4], impl_selector[5] + const auto& bias = test_case.inputs()[0]; + const auto& mat1 = test_case.inputs()[1]; + const auto& mat2 = test_case.inputs()[2]; + auto mat1_sizes = mat1.get_tensor_sizes(); + auto mat2_sizes = mat2.get_tensor_sizes(); + + int64_t M = mat1_sizes[0]; + int64_t K = mat1_sizes[1]; + int64_t N = mat2_sizes[1]; + + auto& bias_data = bias.get_float_data(); + auto& mat1_data = mat1.get_float_data(); + auto& mat2_data = mat2.get_float_data(); + auto& ref_data = output.get_ref_float_data(); + ref_data.resize(M * N, 0.0f); + + float alpha = test_case.inputs()[4].get_float_value(); + float beta = test_case.inputs()[3].get_float_value(); + + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += mat1_data[m * K + k] * mat2_data[k * N + n]; + } + float bias_val = + (n < static_cast(bias_data.size())) ? 
bias_data[n] : 0.0f; + ref_data[m * N + n] = beta * bias_val + alpha * sum; + } + } + } else if (op_name == "test_etvk.test_linear.default") { + // input[0], weight[1], bias[2], impl_selector[3] + const auto& input = test_case.inputs()[0]; + const auto& weight = test_case.inputs()[1]; + const auto& bias_spec = test_case.inputs()[2]; + auto input_sizes = input.get_tensor_sizes(); + auto weight_sizes = weight.get_tensor_sizes(); + + int64_t M = input_sizes[0]; + int64_t K = input_sizes[1]; + int64_t N = weight_sizes[0]; + + auto& input_data = input.get_float_data(); + auto& weight_data = weight.get_float_data(); + auto& ref_data = output.get_ref_float_data(); + ref_data.resize(M * N, 0.0f); + + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += input_data[m * K + k] * weight_data[n * K + k]; + } + if (!bias_spec.is_none()) { + auto& bias_data = bias_spec.get_float_data(); + sum += bias_data[n]; + } + ref_data[m * N + n] = sum; + } + } + } +} + +static std::vector generate_mm_test_cases() { + std::vector test_cases; + + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + std::vector shapes = { + // Accuracy shapes (float) + {0, 64, 128, 64}, + {0, 128, 256, 128}, + {0, 32, 64, 256}, + {1, 64, 128, 64}, + {1, 4, 32, 16}, + // Non-multiple-of-4 accuracy shapes (exercises scalar shader fallback) + {0, 57, 131, 43}, + {0, 33, 67, 91}, + {1, 19, 53, 37}, + {0, 64, 128, 47}, // only N unaligned + {0, 64, 47, 128}, // only K unaligned + // Performance shapes (half) + {0, 4096, 1024, 256}, + {0, 4096, 256, 128}, + {0, 4096, 128, 256}, + {1, 4096, 256, 1024}, + {1, 4096, 256, 128}, + {1, 256, 4096, 64}, + {0, 4096, 64, 128}, + }; + + for (const auto& s : shapes) { + bool is_batched = s.B > 0; + bool is_perf = s.M > kRefDimSizeLimit || s.K > kRefDimSizeLimit || + s.N > kRefDimSizeLimit; + + std::vector dtypes = is_perf + ? 
std::vector{vkapi::kFloat, vkapi::kHalf} + : std::vector{vkapi::kFloat}; + + MmConfig dynamic_cfg{s.B, s.M, s.K, s.N, false, false, false}; + MmConfig const_cfg{s.B, s.M, s.K, s.N, false, false, true}; + + for (auto dtype : dtypes) { + for (auto st : storage_types) { + test_cases.push_back( + create_mm_test_case(dynamic_cfg, dtype, st, utils::kWidthPacked)); + test_cases.push_back( + create_mm_test_case(const_cfg, dtype, st, utils::kWidthPacked)); + } + + if (!is_batched) { + MmConfig addmm_cfg{s.B, s.M, s.K, s.N, true, false, false}; + MmConfig addmm_const_cfg{s.B, s.M, s.K, s.N, true, false, true}; + MmConfig linear_cfg{s.B, s.M, s.K, s.N, false, true, false}; + MmConfig linear_bias_cfg{s.B, s.M, s.K, s.N, true, true, false}; + + for (auto st : storage_types) { + test_cases.push_back( + create_mm_test_case(addmm_cfg, dtype, st, utils::kWidthPacked)); + test_cases.push_back(create_mm_test_case( + addmm_const_cfg, dtype, st, utils::kWidthPacked)); + test_cases.push_back( + create_mm_test_case(linear_cfg, dtype, st, utils::kWidthPacked)); + test_cases.push_back(create_mm_test_case( + linear_bias_cfg, dtype, st, utils::kWidthPacked)); + } + } + } + } + + return test_cases; +} + +static int64_t mm_flop_calculator(const TestCase& test_case) { + const auto& out_sizes = test_case.outputs()[0].get_tensor_sizes(); + const std::string& op_name = test_case.operator_name(); + + int64_t M, N, K; + + if (op_name == "test_etvk.test_mm.default") { + // mat1[0], mat2[1], impl_selector[2] + auto mat1_sizes = test_case.inputs()[0].get_tensor_sizes(); + M = mat1_sizes[0]; + K = mat1_sizes[1]; + N = out_sizes[1]; + return 2 * M * K * N; + } else if (op_name == "test_etvk.test_bmm.default") { + // mat1[0], mat2[1], impl_selector[2] + auto mat1_sizes = test_case.inputs()[0].get_tensor_sizes(); + int64_t B = mat1_sizes[0]; + M = mat1_sizes[1]; + K = mat1_sizes[2]; + N = out_sizes[2]; + return 2 * B * M * K * N; + } else if (op_name == "test_etvk.test_addmm.default") { + // self[0], 
mat1[1], mat2[2], beta[3], alpha[4], impl_selector[5] + auto mat1_sizes = test_case.inputs()[1].get_tensor_sizes(); + M = mat1_sizes[0]; + K = mat1_sizes[1]; + N = out_sizes[1]; + return 2 * M * K * N; + } else if (op_name == "test_etvk.test_linear.default") { + // input[0], weight[1], bias[2], impl_selector[3] + auto input_sizes = test_case.inputs()[0].get_tensor_sizes(); + auto weight_sizes = test_case.inputs()[1].get_tensor_sizes(); + M = input_sizes[0]; + K = input_sizes[1]; + N = weight_sizes[0]; + return 2 * M * K * N; + } + return 1; +} + +static void reference_impl(TestCase& test_case) { + mm_reference_impl(test_case); +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Matrix Multiply (mm/bmm/addmm/linear) Benchmark" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( + generate_mm_test_cases, mm_flop_calculator, "MatMul", 3, 10, ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 6a9db70adaa..1ecf1c677ed 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -122,63 +122,102 @@ def get_binary_elementwise_compare_inputs(): @register_test_suite("aten.mm.default") def get_mm_inputs(): - test_suite = VkTestSuite( - [ - ((M1, L), (L, M2)), - ((S1, S2), (S2, M)), - ((6, 32), (32, 64)), - ], - ) - test_suite.prepacked_args = ["mat2"] - # ATen matmul doesn't support half - test_suite.dtypes = ["at::kFloat"] - test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] - test_suite.layouts = [ + test_cases = [ + ((M1, L), (L, M2)), + ((S1, S2), (S2, M)), + ((6, 32), (32, 64)), + ((XS, S1), (S1, XS)), + ((S, M1), (M1, S2)), + ((M2, S), (S, L)), + ((1, S2), (S2, M1)), + ((M, 1), (1, S1)), + ] + + # Prepacked mat2 exercises 
the linear code path + prepacked_suite = VkTestSuite(test_cases) + prepacked_suite.prepacked_args = ["mat2"] + prepacked_suite.dtypes = ["at::kFloat"] + prepacked_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] + prepacked_suite.layouts = [ "utils::kWidthPacked", - "utils::kChannelsPacked", ] - return test_suite + prepacked_suite.test_name_suffix = "prepacked" + + # Non-prepacked mat2 exercises the matmul code path + dynamic_suite = VkTestSuite(test_cases) + dynamic_suite.dtypes = ["at::kFloat"] + dynamic_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] + dynamic_suite.layouts = [ + "utils::kWidthPacked", + ] + dynamic_suite.test_name_suffix = "dynamic" + + return [prepacked_suite, dynamic_suite] @register_test_suite("aten.bmm.default") def get_bmm_inputs(): - test_suite = VkTestSuite( - [ - ((S, M1, L), (S, L, M2)), - ((M, S1, S2), (M, S2, M)), - ((4, 6, 32), (4, 32, 16)), - ], - ) - test_suite.prepacked_args = ["mat2"] - # ATen matmul doesn't support half - test_suite.dtypes = ["at::kFloat"] - test_suite.layouts = [ + test_cases = [ + ((S, M1, L), (S, L, M2)), + ((M, S1, S2), (M, S2, M)), + ((4, 6, 32), (4, 32, 16)), + ((XS, S, M1), (XS, M1, S2)), + ((1, M2, S1), (1, S1, S)), + ((S1, XS, M), (S1, M, S2)), + ((2, S2, S), (2, S, M1)), + ((XS, 1, S1), (XS, S1, 1)), + ] + + # Prepacked mat2 exercises the linear code path + prepacked_suite = VkTestSuite(test_cases) + prepacked_suite.prepacked_args = ["mat2"] + prepacked_suite.dtypes = ["at::kFloat"] + prepacked_suite.layouts = [ "utils::kWidthPacked", - "utils::kChannelsPacked", ] - return test_suite + prepacked_suite.test_name_suffix = "prepacked" + + # Non-prepacked mat2 exercises the matmul code path + dynamic_suite = VkTestSuite(test_cases) + dynamic_suite.dtypes = ["at::kFloat"] + dynamic_suite.layouts = [ + "utils::kWidthPacked", + ] + dynamic_suite.test_name_suffix = "dynamic" + + return [prepacked_suite, dynamic_suite] @register_test_suite("aten.addmm.default") def 
get_addmm_inputs(): - test_suite = VkTestSuite( - [ - ((1, S), (S1, S), (S, S), 1.0, 1.5), - ((S, 1), (S, S1), (S1, S1), 1.0, 1.0), - ((M1, M2), (M1, M2), (M2, M2)), - ((M1, M2), (M1, M2), (M2, M2), 4.2, 2.3), - ((M1, 1), (M1, L), (L, L), 2.0, 3.0), - ((M2), (M1, M2), (M2, M2)), - ((6, M2), (6, M2), (M2, M2)), - ] - ) - # ATen matmul doesn't support half - test_suite.dtypes = ["at::kFloat"] - test_suite.layouts = [ + test_cases = [ + ((1, S), (S1, S), (S, S), 1.0, 1.5), + ((S, 1), (S, S1), (S1, S1), 1.0, 1.0), + ((M1, M2), (M1, M2), (M2, M2)), + ((M1, M2), (M1, M2), (M2, M2), 4.2, 2.3), + ((M1, 1), (M1, L), (L, L), 2.0, 3.0), + ((M2), (M1, M2), (M2, M2)), + ((6, M2), (6, M2), (M2, M2)), + ] + + # Non-prepacked mat2 exercises the matmul addmm code path + dynamic_suite = VkTestSuite(test_cases) + dynamic_suite.dtypes = ["at::kFloat"] + dynamic_suite.layouts = [ "utils::kWidthPacked", - "utils::kChannelsPacked", ] - return test_suite + dynamic_suite.test_name_suffix = "dynamic" + + # Prepacked mat2 exercises the linear code path + prepacked_suite = VkTestSuite(test_cases) + prepacked_suite.prepacked_args = ["mat2"] + prepacked_suite.dtypes = ["at::kFloat"] + prepacked_suite.layouts = [ + "utils::kWidthPacked", + ] + prepacked_suite.test_name_suffix = "prepacked" + + return [dynamic_suite, prepacked_suite] common_MKN_list = [ @@ -201,7 +240,6 @@ def get_linear_inputs(): test_suite.dtypes = ["at::kFloat"] test_suite.layouts = [ "utils::kWidthPacked", - "utils::kChannelsPacked", ] test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] return test_suite