Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 55 additions & 39 deletions src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ class RoPEKernelNeox {
size_t _tile_len;
size_t _copy_len;
size_t _half_len;
size_t _half_copy_len;
size_t _batch;
size_t _nhead;

Expand Down Expand Up @@ -399,6 +400,7 @@ __aicore__ inline void RoPEKernelNeox<T, U>::init(GM_ADDR y,
this->_st_xnh = st_xnh;
this->_st_xbatch = st_xbatch;
_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_half_copy_len = alignTileLen<T>(_half_len, BYTE_ALIGN);

_block_idx = GetBlockIdx();

Expand All @@ -410,27 +412,28 @@ __aicore__ inline void RoPEKernelNeox<T, U>::init(GM_ADDR y,

pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(_out_que, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_len * sizeof(T));
pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_len * sizeof(T));
pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));

if constexpr (std::is_same<T, bfloat16_t>::value) {
size_t half_float_copy_len = alignTileLen<float>(_half_len, BYTE_ALIGN);
pipe.InitBuffer(_tmp_float_input, _copy_len * sizeof(float));
pipe.InitBuffer(_tmp_float_sin, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_float_cos, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_float_sin, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_float_cos, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_float_output, _tile_len * sizeof(float));
pipe.InitBuffer(_tmp_first_half, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_second_half, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_result1, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_result2, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_result3, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_result4, _half_len * sizeof(float));
pipe.InitBuffer(_tmp_first_half, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_second_half, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_result1, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_result2, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_result3, half_float_copy_len * sizeof(float));
pipe.InitBuffer(_tmp_result4, half_float_copy_len * sizeof(float));
} else {
pipe.InitBuffer(_tmp_first_half, _half_len * sizeof(T));
pipe.InitBuffer(_tmp_second_half, _half_len * sizeof(T));
pipe.InitBuffer(_tmp_result1, _half_len * sizeof(T));
pipe.InitBuffer(_tmp_result2, _half_len * sizeof(T));
pipe.InitBuffer(_tmp_result3, _half_len * sizeof(T));
pipe.InitBuffer(_tmp_result4, _half_len * sizeof(T));
pipe.InitBuffer(_tmp_first_half, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_second_half, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_result1, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_result2, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_result3, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_result4, _half_copy_len * sizeof(T));
}
}

Expand All @@ -446,8 +449,10 @@ __aicore__ inline void RoPEKernelNeox<T, U>::copyIn(size_t i) {
auto idx = batch_idx * _st_xbatch + i * _st_xnt + head_idx * _st_xnh;
DataCopy(input_ub, _x_gm[idx], _copy_len);
auto pos_idx = _p_gm(i);
DataCopy(sin_ub, _sin_gm[pos_idx * _half_len], _half_len);
DataCopy(cos_ub, _cos_gm[pos_idx * _half_len], _half_len);
DataCopyExtParams halfCopyParams = {1, static_cast<uint32_t>(_half_len * sizeof(T)), 0, 0, 0};
DataCopyPadExtParams<T> halfPadParams{true, 0, 0, 0};
DataCopyPad(sin_ub, _sin_gm[pos_idx * _half_len], halfCopyParams, halfPadParams);
DataCopyPad(cos_ub, _cos_gm[pos_idx * _half_len], halfCopyParams, halfPadParams);
_in_que.EnQue(input_ub);
_sin_que.EnQue(sin_ub);
_cos_que.EnQue(cos_ub);
Expand All @@ -470,26 +475,32 @@ __aicore__ inline void RoPEKernelNeox<T, U>::compute(size_t i) {
LocalTensor<float> result2_f = _tmp_result2.Get<float>();
LocalTensor<float> result3_f = _tmp_result3.Get<float>();
LocalTensor<float> result4_f = _tmp_result4.Get<float>();
size_t half_float_copy_len = alignTileLen<float>(_half_len, BYTE_ALIGN);

Cast(input_f, input_ub, AscendC::RoundMode::CAST_NONE, _copy_len);
Cast(sin_f, sin_ub, AscendC::RoundMode::CAST_NONE, _half_len);
Cast(cos_f, cos_ub, AscendC::RoundMode::CAST_NONE, _half_len);

for (size_t j = 0; j < _half_len; j++) {
first_half_f(j) = input_f(j);
second_half_f(j) = input_f(_half_len + j);
Cast(sin_f, sin_ub, AscendC::RoundMode::CAST_NONE, half_float_copy_len);
Cast(cos_f, cos_ub, AscendC::RoundMode::CAST_NONE, half_float_copy_len);

for (size_t j = 0; j < half_float_copy_len; j++) {
if (j < _half_len) {
first_half_f(j) = input_f(j);
second_half_f(j) = input_f(_half_len + j);
} else {
first_half_f(j) = 0.0f;
second_half_f(j) = 0.0f;
}
}
PipeBarrier<PIPE_V>();

Mul<float>(result1_f, first_half_f, cos_f, _half_len);
Mul<float>(result2_f, second_half_f, sin_f, _half_len);
Mul<float>(result1_f, first_half_f, cos_f, half_float_copy_len);
Mul<float>(result2_f, second_half_f, sin_f, half_float_copy_len);
PipeBarrier<PIPE_V>();
Sub<float>(result3_f, result1_f, result2_f, _half_len);
Sub<float>(result3_f, result1_f, result2_f, half_float_copy_len);

Mul<float>(result1_f, first_half_f, sin_f, _half_len);
Mul<float>(result2_f, second_half_f, cos_f, _half_len);
Mul<float>(result1_f, first_half_f, sin_f, half_float_copy_len);
Mul<float>(result2_f, second_half_f, cos_f, half_float_copy_len);
PipeBarrier<PIPE_V>();
Add<float>(result4_f, result1_f, result2_f, _half_len);
Add<float>(result4_f, result1_f, result2_f, half_float_copy_len);

LocalTensor<float> output_f = _tmp_float_output.Get<float>();
for (size_t j = 0; j < _half_len; j++) {
Expand All @@ -507,21 +518,26 @@ __aicore__ inline void RoPEKernelNeox<T, U>::compute(size_t i) {
LocalTensor<T> result3 = _tmp_result3.Get<T>();
LocalTensor<T> result4 = _tmp_result4.Get<T>();

for (size_t j = 0; j < _half_len; j++) {
first_half(j) = input_ub(j);
second_half(j) = input_ub(_half_len + j);
for (size_t j = 0; j < _half_copy_len; j++) {
if (j < _half_len) {
first_half(j) = input_ub(j);
second_half(j) = input_ub(_half_len + j);
} else {
first_half(j) = static_cast<T>(0);
second_half(j) = static_cast<T>(0);
}
}
PipeBarrier<PIPE_V>();

Mul<T>(result1, first_half, cos_ub, _half_len);
Mul<T>(result2, second_half, sin_ub, _half_len);
Mul<T>(result1, first_half, cos_ub, _half_copy_len);
Mul<T>(result2, second_half, sin_ub, _half_copy_len);
PipeBarrier<PIPE_V>();
Sub<T>(result3, result1, result2, _half_len);
Sub<T>(result3, result1, result2, _half_copy_len);

Mul<T>(result1, first_half, sin_ub, _half_len);
Mul<T>(result2, second_half, cos_ub, _half_len);
Mul<T>(result1, first_half, sin_ub, _half_copy_len);
Mul<T>(result2, second_half, cos_ub, _half_copy_len);
PipeBarrier<PIPE_V>();
Add<T>(result4, result1, result2, _half_len);
Add<T>(result4, result1, result2, _half_copy_len);

for (size_t j = 0; j < _half_len; j++) {
output_ub(j) = result3(j);
Expand Down
Loading