Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
#include "gemm_kunlun.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_xblas.h"
#include "gemm_kunlun_cast.h"

namespace op::gemm::kunlun {

// Local shorthand for device::kunlun::blas::Handle::Internal, the type held by
// Descriptor::Opaque in this file.
typedef device::kunlun::blas::Handle::Internal HandleInternal;

// Returns the number of elements spanned by `batch` instances of `matrix`,
// counted from the first element up to and including the last addressable one:
//   (batch - 1) * stride + (rows - 1) * row_stride + (cols - 1) * col_stride + 1
//
// A batch stride of 0 means every batch instance maps onto the same storage,
// so only a single matrix is counted in that case.
//
// Returns 0 when there is nothing to store (batch, rows or cols is zero).
// NOTE(review): strides are assumed to be non-negative — confirm against
// BlasMatrix / MatmulInfo::create.
static size_t matrixStorageSize(const BlasMatrix &matrix, size_t batch) {
    // Without this guard the `- 1` terms below wrap around / go negative for
    // empty inputs, and the final cast to size_t can yield a huge bogus size.
    if (batch == 0 || matrix.rows == 0 || matrix.cols == 0) {
        return 0;
    }

    // Offset of the first element of the last batch instance.
    auto batch_offset = matrix.stride == 0 ? 0 : static_cast<ptrdiff_t>(batch - 1) * matrix.stride;
    // Offset of the last element inside one matrix.
    auto last_offset = static_cast<ptrdiff_t>(matrix.rows - 1) * matrix.row_stride
                     + static_cast<ptrdiff_t>(matrix.cols - 1) * matrix.col_stride;
    // +1 converts the largest offset into an element count.
    return static_cast<size_t>(batch_offset + last_offset + 1);
}

// Workspace size, in bytes, used by the BF16 path: room for one 2-byte-per-
// element temporary copy of each of A, B and C, each sized by
// matrixStorageSize() for the full batch.
static size_t bf16WorkspaceSize(const MatmulInfo &info) {
    constexpr size_t bytes_per_element = 2;

    const size_t a_elements = matrixStorageSize(info.a_matrix, info.batch);
    const size_t b_elements = matrixStorageSize(info.b_matrix, info.batch);
    const size_t c_elements = matrixStorageSize(info.c_matrix, info.batch);

    const size_t total_elements = a_elements + b_elements + c_elements;
    return total_elements * bytes_per_element;
}

// Backend-private state of the Kunlun GEMM descriptor.
struct Descriptor::Opaque {
// Shared handle internals; Descriptor::calculate goes through
// internal->useCublas(...) to obtain a cublasHandle_t for the GEMM call.
std::shared_ptr<HandleInternal> internal;
};
Expand All @@ -27,9 +43,11 @@ infiniStatus_t Descriptor::create(

auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
auto info = result.take();
auto workspace_size = dtype == INFINI_DTYPE_BF16 ? bf16WorkspaceSize(info) : 0;

*desc_ptr = new Descriptor(
dtype, result.take(), 0,
dtype, info, workspace_size,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
Expand Down Expand Up @@ -72,6 +90,57 @@ infiniStatus_t Descriptor::calculate(
auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;

if (_dtype == INFINI_DTYPE_BF16) {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

auto workspace_bytes = reinterpret_cast<char *>(workspace);
auto a_tmp = workspace_bytes;
auto b_tmp = a_tmp + matrixStorageSize(_info.a_matrix, _info.batch) * 2;
auto c_tmp = b_tmp + matrixStorageSize(_info.b_matrix, _info.batch) * 2;
auto temp_type = CUDA_R_16F;

CHECK_STATUS(castBf16ToF16(a, a_tmp, _info.a_matrix, _info.batch, (kunlunStream_t)stream));
CHECK_STATUS(castBf16ToF16(b, b_tmp, _info.b_matrix, _info.batch, (kunlunStream_t)stream));
CHECK_STATUS(castBf16ToF16(c, c_tmp, _info.c_matrix, _info.batch, (kunlunStream_t)stream));

CHECK_STATUS(_opaque->internal->useCublas(
(cudaStream_t)stream,
[&](cublasHandle_t handle) {
CHECK_CUBLAS(
cublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(_info.m),
static_cast<int>(_info.n),
static_cast<int>(_info.k),
&alpha,
a_tmp,
temp_type,
static_cast<int>(_info.a_matrix.ld()),
_info.a_matrix.stride,
b_tmp,
temp_type,
static_cast<int>(_info.b_matrix.ld()),
_info.b_matrix.stride,
&beta,
c_tmp,
temp_type,
static_cast<int>(_info.c_matrix.ld()),
_info.c_matrix.stride,
static_cast<int>(_info.batch),
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
return INFINI_STATUS_SUCCESS;
}));

CHECK_STATUS(castF16ToBf16(c_tmp, c, _info.c_matrix, _info.batch, (kunlunStream_t)stream));
xpu_wait(stream);
return INFINI_STATUS_SUCCESS;
}

CHECK_STATUS(_opaque->internal->useCublas(
(cudaStream_t)stream,
[&](cublasHandle_t handle) {
Expand Down
25 changes: 25 additions & 0 deletions src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef __GEMM_KUNLUN_CAST_H__
#define __GEMM_KUNLUN_CAST_H__

#include "../../../devices/kunlun/kunlun_common.h"
#include "../info.h"

namespace op::gemm::kunlun {

// Converts a (batched, strided) matrix of bfloat16 values to half precision.
//
// src    - source buffer, read as bfloat16 elements.
// dst    - destination buffer, written as half elements. Each element is
//          written at the same element offset it was read from
//          (batch * stride + row * row_stride + col * col_stride), so dst must
//          cover the same strided extent as src.
// matrix - supplies rows, cols, stride, row_stride and col_stride.
// batch  - number of batch instances to convert.
// stream - stream the conversion kernel is launched on.
//
// Returns INFINI_STATUS_SUCCESS after issuing the launch; the implementation
// does not wait on the stream itself.
infiniStatus_t castBf16ToF16(
const void *src,
void *dst,
const BlasMatrix &matrix,
size_t batch,
kunlunStream_t stream);

// Converts a (batched, strided) matrix of half-precision values to bfloat16.
//
// src    - source buffer, read as half elements.
// dst    - destination buffer, written as bfloat16 elements. Each element is
//          written at the same element offset it was read from
//          (batch * stride + row * row_stride + col * col_stride), so dst must
//          cover the same strided extent as src.
// matrix - supplies rows, cols, stride, row_stride and col_stride.
// batch  - number of batch instances to convert.
// stream - stream the conversion kernel is launched on.
//
// Returns INFINI_STATUS_SUCCESS after issuing the launch; the implementation
// does not wait on the stream itself.
infiniStatus_t castF16ToBf16(
const void *src,
void *dst,
const BlasMatrix &matrix,
size_t batch,
kunlunStream_t stream);

} // namespace op::gemm::kunlun

#endif // __GEMM_KUNLUN_CAST_H__
91 changes: 91 additions & 0 deletions src/infiniop/ops/gemm/kunlun/gemm_kunlun_cast.xpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "gemm_kunlun_cast.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"

namespace op::gemm::kunlun {

// Element-wise dtype conversion between bfloat16 and half for a batched,
// strided matrix.
//
// The logical index space batch x rows x cols is flattened; every
// participating core starts at its global id (core_num() * cluster_id() +
// core_id()) and advances by the total core count (core_num() * cluster_num())
// until the flattened range is exhausted. Each element is read from
// `src + offset` and written to `dst + offset`, where
//   offset = batch_id * stride + row * row_stride + col * col_stride
// so source and destination share the same strided layout.
//
// When Tin is bfloat16_t the value is converted bf16 -> float -> half;
// otherwise it is converted half -> float -> bf16 (Tin is expected to be half
// in that case).
template <typename Tin, typename Tout>
__global__ void castMatrixKernel(
    const Tin *src,
    Tout *dst,
    long long batch,
    long long rows,
    long long cols,
    long long stride,
    long long row_stride,
    long long col_stride) {

    const int local_core = core_id();
    const int cores_per_cluster = core_num();
    if (local_core >= cores_per_cluster) {
        return;
    }
    const int worker_id = cores_per_cluster * cluster_id() + local_core;
    const int worker_count = cores_per_cluster * cluster_num();

    const long long elements_per_matrix = rows * cols;
    const long long element_count = batch * elements_per_matrix;

    long long linear = worker_id;
    while (linear < element_count) {
        // Decompose the flat index into (batch, row, col).
        const long long b = linear / elements_per_matrix;
        const long long within_matrix = linear - b * elements_per_matrix;
        const long long r = within_matrix / cols;
        const long long c = within_matrix - r * cols;

        // A batch stride of 0 maps every batch instance onto the same storage.
        long long element_offset = r * row_stride + c * col_stride;
        if (stride != 0) {
            element_offset += b * stride;
        }

        // Stage one element through per-core locals (GM2LM / LM2GM are
        // presumably global<->local memory copies — TODO confirm); each async
        // copy is followed by mfence() before the value is used or reused.
        Tin loaded;
        Tout converted;
        GM2LM_ASYNC(src + element_offset, &loaded, sizeof(Tin));
        mfence();
        if constexpr (xpu_std::is_same<Tin, bfloat16_t>::value) {
            converted = __float2half(__bfloat162float(loaded));
        } else {
            converted = __float2bfloat16(__half2float(loaded));
        }
        LM2GM_ASYNC(&converted, dst + element_offset, sizeof(Tout));
        mfence();

        linear += worker_count;
    }
    sync_cluster();
}

// Launch shape used for both cast directions (the values previously
// hard-coded in each launcher). Given how castMatrixKernel derives its ids
// from cluster_num()/core_num(), these are presumably the cluster count and
// the cores per cluster — TODO confirm against the XPU launch syntax.
static constexpr int CAST_LAUNCH_CLUSTERS = 12;
static constexpr int CAST_LAUNCH_CORES = 64;

// Shared launcher for castMatrixKernel<Tin, Tout>.
//
// Reinterprets `src`/`dst` as Tin/Tout buffers and enqueues the conversion of
// `batch` instances of `matrix` on `stream`. Returns INFINI_STATUS_SUCCESS
// without launching anything when there are no elements to convert.
template <typename Tin, typename Tout>
static infiniStatus_t launchCastMatrix(
    const void *src,
    void *dst,
    const BlasMatrix &matrix,
    size_t batch,
    kunlunStream_t stream) {

    // Nothing to convert: skip the launch (the kernel would do no work).
    if (batch == 0 || matrix.rows == 0 || matrix.cols == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    castMatrixKernel<Tin, Tout>
        <<<CAST_LAUNCH_CLUSTERS, CAST_LAUNCH_CORES, stream>>>(
            reinterpret_cast<const Tin *>(src),
            reinterpret_cast<Tout *>(dst),
            static_cast<long long>(batch),
            static_cast<long long>(matrix.rows),
            static_cast<long long>(matrix.cols),
            static_cast<long long>(matrix.stride),
            static_cast<long long>(matrix.row_stride),
            static_cast<long long>(matrix.col_stride));
    return INFINI_STATUS_SUCCESS;
}

// Converts a batched, strided bfloat16 matrix in `src` to half precision in
// `dst`, using the same element offsets for both buffers. The kernel is
// launched on `stream`; this function does not wait for it.
// Always returns INFINI_STATUS_SUCCESS.
infiniStatus_t castBf16ToF16(
    const void *src,
    void *dst,
    const BlasMatrix &matrix,
    size_t batch,
    kunlunStream_t stream) {

    return launchCastMatrix<bfloat16_t, half>(src, dst, matrix, batch, stream);
}

// Converts a batched, strided half-precision matrix in `src` to bfloat16 in
// `dst`, using the same element offsets for both buffers. The kernel is
// launched on `stream`; this function does not wait for it.
// Always returns INFINI_STATUS_SUCCESS.
infiniStatus_t castF16ToBf16(
    const void *src,
    void *dst,
    const BlasMatrix &matrix,
    size_t batch,
    kunlunStream_t stream) {

    return launchCastMatrix<half, bfloat16_t>(src, dst, matrix, batch, stream);
}

} // namespace op::gemm::kunlun
Loading