From 6ee1379a1cc42074b0601d57d73558541674c25f Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 12 Jun 2026 16:01:43 +0800 Subject: [PATCH] issue/1259 - patch kunlun gemm bf16 --- src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc | 71 ++++++++++++++- .../ops/gemm/kunlun/gemm_kunlun_cast.h | 25 +++++ .../ops/gemm/kunlun/gemm_kunlun_cast.xpu | 91 +++++++++++++++++++ 3 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.h create mode 100644 src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.xpu diff --git a/src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc b/src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc index b75f19fcf..e5038f44a 100644 --- a/src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc +++ b/src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc @@ -1,11 +1,27 @@ #include "gemm_kunlun.h" #include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_xblas.h" +#include "gemm_kunlun_cast.h" namespace op::gemm::kunlun { typedef device::kunlun::blas::Handle::Internal HandleInternal; +static size_t matrixStorageSize(const BlasMatrix &matrix, size_t batch) { + auto batch_offset = matrix.stride == 0 ? 0 : static_cast(batch - 1) * matrix.stride; + auto last_offset = static_cast(matrix.rows - 1) * matrix.row_stride + + static_cast(matrix.cols - 1) * matrix.col_stride; + return static_cast(batch_offset + last_offset + 1); +} + +static size_t bf16WorkspaceSize(const MatmulInfo &info) { + constexpr size_t f16_size = 2; + return (matrixStorageSize(info.a_matrix, info.batch) + + matrixStorageSize(info.b_matrix, info.batch) + + matrixStorageSize(info.c_matrix, info.batch)) + * f16_size; +} + struct Descriptor::Opaque { std::shared_ptr internal; }; @@ -27,9 +43,11 @@ infiniStatus_t Descriptor::create( auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR); CHECK_RESULT(result); + auto info = result.take(); + auto workspace_size = dtype == INFINI_DTYPE_BF16 ? bf16WorkspaceSize(info) : 0; *desc_ptr = new Descriptor( - dtype, result.take(), 0, + dtype, info, workspace_size, new Opaque{handle->internal()}, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; @@ -72,6 +90,57 @@ infiniStatus_t Descriptor::calculate( auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; + if (_dtype == INFINI_DTYPE_BF16) { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto workspace_bytes = reinterpret_cast(workspace); + auto a_tmp = workspace_bytes; + auto b_tmp = a_tmp + matrixStorageSize(_info.a_matrix, _info.batch) * 2; + auto c_tmp = b_tmp + matrixStorageSize(_info.b_matrix, _info.batch) * 2; + auto temp_type = CUDA_R_16F; + + CHECK_STATUS(castBf16ToF16(a, a_tmp, _info.a_matrix, _info.batch, (kunlunStream_t)stream)); + CHECK_STATUS(castBf16ToF16(b, b_tmp, _info.b_matrix, _info.batch, (kunlunStream_t)stream)); + CHECK_STATUS(castBf16ToF16(c, c_tmp, _info.c_matrix, _info.batch, (kunlunStream_t)stream)); + + CHECK_STATUS(_opaque->internal->useCublas( + (cudaStream_t)stream, + [&](cublasHandle_t handle) { + CHECK_CUBLAS( + cublasGemmStridedBatchedEx( + handle, + op_a, + op_b, + static_cast(_info.m), + static_cast(_info.n), + static_cast(_info.k), + &alpha, + a_tmp, + temp_type, + static_cast(_info.a_matrix.ld()), + _info.a_matrix.stride, + b_tmp, + temp_type, + static_cast(_info.b_matrix.ld()), + _info.b_matrix.stride, + &beta, + c_tmp, + temp_type, + static_cast(_info.c_matrix.ld()), + _info.c_matrix.stride, + static_cast(_info.batch), + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + return INFINI_STATUS_SUCCESS; + })); + + CHECK_STATUS(castF16ToBf16(c_tmp, c, _info.c_matrix, _info.batch, (kunlunStream_t)stream)); + xpu_wait(stream); + return INFINI_STATUS_SUCCESS; + } + CHECK_STATUS(_opaque->internal->useCublas( (cudaStream_t)stream, [&](cublasHandle_t handle) { diff --git a/src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.h b/src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.h new file mode 100644 index 000000000..b5dfd7ea7 --- /dev/null +++ b/src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.h @@ -0,0 +1,25 @@ +#ifndef __GEMM_KUNLUN_CAST_H__ +#define __GEMM_KUNLUN_CAST_H__ + +#include "../../../devices/kunlun/kunlun_common.h" +#include "../info.h" + +namespace op::gemm::kunlun { + +infiniStatus_t castBf16ToF16( + const void *src, + void *dst, + const BlasMatrix &matrix, + size_t batch, + kunlunStream_t stream); + +infiniStatus_t castF16ToBf16( + const void *src, + void *dst, + const BlasMatrix &matrix, + size_t batch, + kunlunStream_t stream); + +} // namespace op::gemm::kunlun + +#endif // __GEMM_KUNLUN_CAST_H__ diff --git a/src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.xpu b/src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.xpu new file mode 100644 index 000000000..af91ff052 --- /dev/null +++ b/src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.xpu @@ -0,0 +1,91 @@ +#include "gemm_kunlun_cast.h" +#include "../../../devices/kunlun/kunlun_kernel_common.h" + +namespace op::gemm::kunlun { + +template +__global__ void castMatrixKernel( + const Tin *src, + Tout *dst, + long long batch, + long long rows, + long long cols, + long long stride, + long long row_stride, + long long col_stride) { + + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int tid = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + + long long matrix_elements = rows * cols; + long long total = batch * matrix_elements; + + for (long long idx = tid; idx < total; idx += nthreads) { + long long batch_id = idx / matrix_elements; + long long rem = idx - batch_id * matrix_elements; + long long row = rem / cols; + long long col = rem - row * cols; + long long batch_offset = stride == 0 ? 0 : batch_id * stride; + long long offset = batch_offset + row * row_stride + col * col_stride; + + Tin in; + Tout out; + GM2LM_ASYNC(src + offset, &in, sizeof(Tin)); + mfence(); + if constexpr (xpu_std::is_same::value) { + out = __float2half(__bfloat162float(in)); + } else { + out = __float2bfloat16(__half2float(in)); + } + LM2GM_ASYNC(&out, dst + offset, sizeof(Tout)); + mfence(); + } + sync_cluster(); +} + +infiniStatus_t castBf16ToF16( + const void *src, + void *dst, + const BlasMatrix &matrix, + size_t batch, + kunlunStream_t stream) { + + castMatrixKernel + <<<12, 64, stream>>>( + reinterpret_cast(src), + reinterpret_cast(dst), + static_cast(batch), + static_cast(matrix.rows), + static_cast(matrix.cols), + static_cast(matrix.stride), + static_cast(matrix.row_stride), + static_cast(matrix.col_stride)); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t castF16ToBf16( + const void *src, + void *dst, + const BlasMatrix &matrix, + size_t batch, + kunlunStream_t stream) { + + castMatrixKernel + <<<12, 64, stream>>>( + reinterpret_cast(src), + reinterpret_cast(dst), + static_cast(batch), + static_cast(matrix.rows), + static_cast(matrix.cols), + static_cast(matrix.stride), + static_cast(matrix.row_stride), + static_cast(matrix.col_stride)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gemm::kunlun