Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion include/infinicore/adaptor/aten_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
#include <c10/cuda/CUDAStream.h>
#endif

#if defined(ENABLE_CAMBRICON_API)
#include <framework/core/MLUStream.h>
#include <framework/core/stream_guard.h>
#endif

namespace infinicore::adaptor {
inline at::ScalarType to_at_dtype(DataType dtype) {
switch (dtype) {
Expand All @@ -31,9 +36,11 @@ inline at::ScalarType to_at_dtype(DataType dtype) {

inline at::Device to_at_device(const Device &device) {
// PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA).
// Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability.
// Treat MetaX/QY devices as CUDA devices and Cambricon MLU as PrivateUse1 for ATen interoperability.
if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) {
return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CAMBRICON) {
return at::Device(c10::DeviceType::PrivateUse1, device.getIndex());
} else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU);
} else {
Expand All @@ -46,6 +53,10 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t);
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
c10::cuda::CUDAStream get_cuda_stream();
#endif

#if defined(ENABLE_CAMBRICON_API)
// Returns a torch_mlu stream handle built from InfiniCore's current context
// stream and device index (see the definition in aten_adaptor.cc).
// Only declared when the Cambricon backend is enabled.
torch_mlu::MLUStream get_mlu_stream();
#endif
} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
19 changes: 12 additions & 7 deletions include/infinicore/adaptor/flash_attention_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#pragma once
#include "aten_adaptor.hpp"

// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension
// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip MetaX/Cambricon extensions
// export the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
// where the namespace is empty.
#if !defined(ENABLE_METAX_API)
#if !defined(ENABLE_METAX_API) && !defined(ENABLE_CAMBRICON_API)
namespace flash {
#endif
std::vector<at::Tensor>
Expand All @@ -30,9 +30,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
#if defined(ENABLE_CAMBRICON_API)
at::Tensor &k, // Cambricon flash_attn_2_bang exports non-const k/v references.
at::Tensor &v,
#else
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
#endif
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
Expand Down Expand Up @@ -133,7 +138,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
#endif
);

#if !defined(ENABLE_METAX_API)
#if !defined(ENABLE_METAX_API) && !defined(ENABLE_CAMBRICON_API)
} // namespace flash
#endif
#endif // ENABLE_FLASH_ATTN
19 changes: 17 additions & 2 deletions python/infinicore/_preload.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,28 @@ def preload_hpcc() -> None:
_try_load(prefixes, lib)


def preload_cambricon() -> None:
    """Try to import ``torch`` and then ``torch_mlu``; never raise.

    Importing the packages lets torch locate and load its own shared
    libraries. A failure at either step is swallowed on purpose: preloading
    is best-effort, and ``torch_mlu`` is not attempted when ``torch`` itself
    cannot be imported.
    """
    try:
        # Order matters: torch first, torch_mlu only if torch imported cleanly.
        import torch  # noqa: F401
        import torch_mlu  # noqa: F401
    except Exception:
        # Either import failed; nothing to preload, so give up quietly.
        return


def _should_preload_device(device_type: str) -> bool:
"""
Check if preload is needed for a specific device type.
"""
device_env_map = {
"METAX": ["HPCC_PATH", "INFINICORE_PRELOAD_HPCC"], # HPCC/METAX
"CAMBRICON": ["NEUWARE_HOME", "INFINICORE_PRELOAD_CAMBRICON"],
# Add other device types here as needed:
# "ASCEND": ["ASCEND_PATH"],
# "CAMBRICON": ["NEUWARE_HOME"],
}

env_vars = device_env_map.get(device_type, [])
Expand All @@ -90,6 +103,8 @@ def preload_device(device_type: str) -> None:
"""
if device_type == "METAX":
preload_hpcc()
elif device_type == "CAMBRICON":
preload_cambricon()
# Add other device preload functions here as needed:
# elif device_type == "ASCEND":
# preload_ascend()
Expand All @@ -106,9 +121,9 @@ def preload() -> None:
# Device types that may require preload
device_types = [
"METAX", # HPCC/METAX
"CAMBRICON",
# Add other device types here as they are implemented:
# "ASCEND",
# "CAMBRICON",
# etc.
]

Expand Down
7 changes: 7 additions & 0 deletions src/infinicore/adaptor/aten_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ c10::cuda::CUDAStream get_cuda_stream() {
}
#endif

#if defined(ENABLE_CAMBRICON_API)
// Builds a torch_mlu stream from InfiniCore's current context stream so that
// ATen work can be issued on the same queue (callers install it via a stream guard).
torch_mlu::MLUStream get_mlu_stream() {
    // Device index of the current InfiniCore context.
    const auto device_index = infinicore::context::getDevice().getIndex();
    // The context stream handle, converted to the CNRT queue type expected by torch_mlu.
    auto queue = cnrtQueue_t(infinicore::context::getStream());
    return torch_mlu::getStreamFromExternal(queue, device_index);
}
#endif

} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
100 changes: 95 additions & 5 deletions src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@

#include <stdexcept>

#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_CAMBRICON_API)
#include <algorithm>
#include <cstdint>
#include <vector>
#endif

#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#if defined(ENABLE_CAMBRICON_API)
#include <framework/core/stream_guard.h>
#endif
#endif

#if defined(ENABLE_METAX_API)
#if defined(ENABLE_METAX_API) || defined(ENABLE_CAMBRICON_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
Expand Down Expand Up @@ -45,31 +54,111 @@ void *plan(Tensor out,

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_CAMBRICON_API)
#if defined(ENABLE_CAMBRICON_API)
torch_mlu::mlu::MLUStreamGuard guard(infinicore::adaptor::get_mlu_stream());
#else
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
#endif
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

#if defined(ENABLE_CAMBRICON_API)
const bool out_need_copy_back = !p->out->is_contiguous();
Tensor out_work_ic = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
auto out_work = infinicore::adaptor::to_aten_tensor(out_work_ic);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
auto seqlens_k_tensor = infinicore::adaptor::to_aten_tensor(p->seqlens_k);
auto block_table_tensor = infinicore::adaptor::to_aten_tensor(p->block_table);
auto alibi_slopes = p->alibi_slopes
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
: std::nullopt;

if (q.dim() != 4 || out_work.dim() != 4) {
throw std::runtime_error("Cambricon flash-attn KV-cache path expects q/out with shape [batch, seqlen_q, heads, head_dim]");
}

const auto batch_size = q.size(0);
const auto seqlen_q = q.size(1);
const auto num_heads = q.size(2);
const auto head_size = q.size(3);

auto seqlens_k_cpu = seqlens_k_tensor.to(at::kCPU);
auto seqlens_k_data = seqlens_k_cpu.data_ptr<int32_t>();
std::vector<int32_t> cu_seqlens_q_host(batch_size + 1, 0);
std::vector<int32_t> cu_seqlens_k_host(batch_size + 1, 0);
int32_t max_seqlen_k = 0;
for (int64_t i = 0; i < batch_size; ++i) {
cu_seqlens_q_host[i + 1] = cu_seqlens_q_host[i] + static_cast<int32_t>(seqlen_q);
cu_seqlens_k_host[i + 1] = cu_seqlens_k_host[i] + seqlens_k_data[i];
max_seqlen_k = std::max(max_seqlen_k, seqlens_k_data[i]);
}

auto tensor_options = q.options().dtype(at::kInt);
auto cu_seqlens_q = at::tensor(cu_seqlens_q_host, tensor_options);
auto cu_seqlens_k = at::tensor(cu_seqlens_k_host, tensor_options);

auto q_varlen = q.reshape({batch_size * seqlen_q, num_heads, head_size});
auto out_varlen = out_work.reshape({batch_size * seqlen_q, num_heads, head_size});
auto out = std::optional<at::Tensor>(out_varlen);
std::optional<at::Tensor> seqused_k = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt;
auto block_table = std::optional<at::Tensor>(block_table_tensor);

INFINICORE_FLASH_OP(mha_varlen_fwd)
(
q_varlen,
k_cache,
v_cache,
out,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
leftpad_k,
block_table,
alibi_slopes,
static_cast<int>(seqlen_q),
static_cast<int>(max_seqlen_k),
0.0,
p->scale,
false,
seqlen_q > 1,
-1,
-1,
0.0,
false,
std::nullopt);

if (out_need_copy_back) {
p->out->copy_from(out_work_ic);
}
return;
#else
// Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
const bool out_need_copy_back = !p->out->is_contiguous();
Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
#elif defined(ENABLE_QY_API)
Tensor k_cache_work = p->k_cache->contiguous();
Tensor v_cache_work = p->v_cache->contiguous();
auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work);
auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work);
#endif
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
#endif
auto alibi_slopes = p->alibi_slopes
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
: std::nullopt;

auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));

std::optional<const at::Tensor> k_new = std::nullopt;
std::optional<const at::Tensor> v_new = std::nullopt;
std::optional<const at::Tensor> rotary_cos = std::nullopt;
Expand Down Expand Up @@ -121,6 +210,7 @@ void run(void *planned_meta) {
if (out_need_copy_back) {
p->out->copy_from(out_work);
}
#endif // ENABLE_CAMBRICON_API
#else
throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
Expand Down
13 changes: 10 additions & 3 deletions src/infinicore/ops/multi_head_attention/mha_flashattn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#if defined(ENABLE_CAMBRICON_API)
#include <framework/core/stream_guard.h>
#endif
#endif

namespace infinicore::op::mha_impl::flashattn {
Expand Down Expand Up @@ -40,10 +43,10 @@ void *plan(Tensor out,
namespace {

// Available when flash-attn is enabled for an NVIDIA, MetaX, QY, or Cambricon build.
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_NVIDIA_API)
#if defined(ENABLE_FLASH_ATTN) && (defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_CAMBRICON_API))
// MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_fwd` at global scope (no namespace),
// while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_fwd`.
#if defined(ENABLE_METAX_API)
#if defined(ENABLE_METAX_API) || defined(ENABLE_CAMBRICON_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
Expand All @@ -54,8 +57,12 @@ namespace {

void run(void *planned_meta) {
// Runs only when flash-attn is enabled for an NVIDIA, MetaX, QY, or Cambricon build.
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_NVIDIA_API)
#if defined(ENABLE_FLASH_ATTN) && (defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_CAMBRICON_API))
#if defined(ENABLE_CAMBRICON_API)
torch_mlu::mlu::MLUStreamGuard guard(infinicore::adaptor::get_mlu_stream());
#else
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

auto q = infinicore::adaptor::to_aten_tensor(p->q);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#if defined(ENABLE_CAMBRICON_API)
#include <framework/core/stream_guard.h>
#endif
#endif

namespace infinicore::op::mha_varlen_impl::flashattn {
Expand Down Expand Up @@ -50,7 +53,7 @@ namespace {
#ifdef ENABLE_FLASH_ATTN
// MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_varlen_fwd` at global scope (no namespace),
// while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_varlen_fwd`.
#if defined(ENABLE_METAX_API)
#if defined(ENABLE_METAX_API) || defined(ENABLE_CAMBRICON_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
Expand All @@ -61,7 +64,11 @@ namespace {

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_CAMBRICON_API)
torch_mlu::mlu::MLUStreamGuard guard(infinicore::adaptor::get_mlu_stream());
#else
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

auto q = infinicore::adaptor::to_aten_tensor(p->q);
Expand Down
Loading
Loading