From 6e690723599e0e332bc98bff05a56d70ecc8640b Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 11 Jun 2026 10:53:59 +0800 Subject: [PATCH] issue/1233 - cambricon fa2 --- include/infinicore/adaptor/aten_adaptor.hpp | 13 ++- .../adaptor/flash_attention_adaptor.hpp | 19 ++-- python/infinicore/_preload.py | 19 +++- src/infinicore/adaptor/aten_adaptor.cc | 7 ++ .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 100 +++++++++++++++++- .../ops/multi_head_attention/mha_flashattn.cc | 13 ++- .../mha_varlen_flashattn.cc | 9 +- xmake.lua | 49 ++++++--- xmake/bang.lua | 78 ++++++++++++++ 9 files changed, 276 insertions(+), 31 deletions(-) diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 00d5cbec2..4cec69fa3 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -11,6 +11,11 @@ #include #endif +#if defined(ENABLE_CAMBRICON_API) +#include +#include +#endif + namespace infinicore::adaptor { inline at::ScalarType to_at_dtype(DataType dtype) { switch (dtype) { @@ -31,9 +36,11 @@ inline at::ScalarType to_at_dtype(DataType dtype) { inline at::Device to_at_device(const Device &device) { // PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA). - // Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability. + // Treat MetaX/QY devices as CUDA devices and Cambricon MLU as PrivateUse1 for ATen interoperability. if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) { return at::Device(at::kCUDA, device.getIndex()); + } else if (device.getType() == Device::Type::CAMBRICON) { + return at::Device(c10::DeviceType::PrivateUse1, device.getIndex()); } else if (device.getType() == Device::Type::CPU) { return at::Device(at::kCPU); } else { @@ -46,6 +53,10 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t); #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) c10::cuda::CUDAStream get_cuda_stream(); #endif + +#if defined(ENABLE_CAMBRICON_API) +torch_mlu::MLUStream get_mlu_stream(); +#endif } // namespace infinicore::adaptor #endif // ENABLE_ATEN diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp index c5bf14858..625eb0e98 100644 --- a/include/infinicore/adaptor/flash_attention_adaptor.hpp +++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp @@ -2,10 +2,10 @@ #pragma once #include "aten_adaptor.hpp" -// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension -// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds +// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip MetaX/Cambricon extensions +// export the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds // where the namespace is empty. -#if !defined(ENABLE_METAX_API) +#if !defined(ENABLE_METAX_API) && !defined(ENABLE_CAMBRICON_API) namespace flash { #endif std::vector @@ -30,9 +30,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num ); std::vector -mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +#if defined(ENABLE_CAMBRICON_API) + at::Tensor &k, // Cambricon flash_attn_2_bang exports non-const k/v references. + at::Tensor &v, +#else + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. +#endif std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -133,7 +138,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size #endif ); -#if !defined(ENABLE_METAX_API) +#if !defined(ENABLE_METAX_API) && !defined(ENABLE_CAMBRICON_API) } // namespace flash #endif #endif // ENABLE_FLASH_ATTN diff --git a/python/infinicore/_preload.py b/python/infinicore/_preload.py index fc5ff6560..cf0ca83f4 100644 --- a/python/infinicore/_preload.py +++ b/python/infinicore/_preload.py @@ -63,15 +63,28 @@ def preload_hpcc() -> None: _try_load(prefixes, lib) +def preload_cambricon() -> None: + """Best-effort import of torch MLU so torch resolves its own shared libraries.""" + try: + import torch # noqa: F401 + except Exception: + return + + try: + import torch_mlu # noqa: F401 + except Exception: + pass + + def _should_preload_device(device_type: str) -> bool: """ Check if preload is needed for a specific device type. """ device_env_map = { "METAX": ["HPCC_PATH", "INFINICORE_PRELOAD_HPCC"], # HPCC/METAX + "CAMBRICON": ["NEUWARE_HOME", "INFINICORE_PRELOAD_CAMBRICON"], # Add other device types here as needed: # "ASCEND": ["ASCEND_PATH"], - # "CAMBRICON": ["NEUWARE_HOME"], } env_vars = device_env_map.get(device_type, []) @@ -90,6 +103,8 @@ def preload_device(device_type: str) -> None: """ if device_type == "METAX": preload_hpcc() + elif device_type == "CAMBRICON": + preload_cambricon() # Add other device preload functions here as needed: # elif device_type == "ASCEND": # preload_ascend() @@ -106,9 +121,9 @@ def preload() -> None: # Device types that may require preload device_types = [ "METAX", # HPCC/METAX + "CAMBRICON", # Add other device types here as they are implemented: # "ASCEND", - # "CAMBRICON", # etc. ] diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index 04db643f9..04b286af0 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -39,6 +39,13 @@ c10::cuda::CUDAStream get_cuda_stream() { } #endif +#if defined(ENABLE_CAMBRICON_API) +torch_mlu::MLUStream get_mlu_stream() { + return torch_mlu::getStreamFromExternal( + cnrtQueue_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); +} +#endif + } // namespace infinicore::adaptor #endif // ENABLE_ATEN diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index 0167c17df..8fa95daff 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -4,13 +4,22 @@ #include +#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_CAMBRICON_API) +#include +#include +#include +#endif + #ifdef ENABLE_FLASH_ATTN #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) #include #endif +#if defined(ENABLE_CAMBRICON_API) +#include +#endif #endif -#if defined(ENABLE_METAX_API) +#if defined(ENABLE_METAX_API) || defined(ENABLE_CAMBRICON_API) #define INFINICORE_FLASH_OP(name) ::name #else #define INFINICORE_FLASH_OP(name) flash::name @@ -45,11 +54,88 @@ void *plan(Tensor out, void run(void *planned_meta) { #ifdef ENABLE_FLASH_ATTN -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_CAMBRICON_API) +#if defined(ENABLE_CAMBRICON_API) + torch_mlu::mlu::MLUStreamGuard guard(infinicore::adaptor::get_mlu_stream()); +#else c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif #endif auto *p = reinterpret_cast(planned_meta); +#if defined(ENABLE_CAMBRICON_API) + const bool out_need_copy_back = !p->out->is_contiguous(); + Tensor out_work_ic = out_need_copy_back ? p->out->contiguous() : Tensor(p->out); + auto out_work = infinicore::adaptor::to_aten_tensor(out_work_ic); + auto q = infinicore::adaptor::to_aten_tensor(p->q); + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); + auto seqlens_k_tensor = infinicore::adaptor::to_aten_tensor(p->seqlens_k); + auto block_table_tensor = infinicore::adaptor::to_aten_tensor(p->block_table); + auto alibi_slopes = p->alibi_slopes + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) + : std::nullopt; + + if (q.dim() != 4 || out_work.dim() != 4) { + throw std::runtime_error("Cambricon flash-attn KV-cache path expects q/out with shape [batch, seqlen_q, heads, head_dim]"); + } + + const auto batch_size = q.size(0); + const auto seqlen_q = q.size(1); + const auto num_heads = q.size(2); + const auto head_size = q.size(3); + + auto seqlens_k_cpu = seqlens_k_tensor.to(at::kCPU); + auto seqlens_k_data = seqlens_k_cpu.data_ptr(); + std::vector cu_seqlens_q_host(batch_size + 1, 0); + std::vector cu_seqlens_k_host(batch_size + 1, 0); + int32_t max_seqlen_k = 0; + for (int64_t i = 0; i < batch_size; ++i) { + cu_seqlens_q_host[i + 1] = cu_seqlens_q_host[i] + static_cast(seqlen_q); + cu_seqlens_k_host[i + 1] = cu_seqlens_k_host[i] + seqlens_k_data[i]; + max_seqlen_k = std::max(max_seqlen_k, seqlens_k_data[i]); + } + + auto tensor_options = q.options().dtype(at::kInt); + auto cu_seqlens_q = at::tensor(cu_seqlens_q_host, tensor_options); + auto cu_seqlens_k = at::tensor(cu_seqlens_k_host, tensor_options); + + auto q_varlen = q.reshape({batch_size * seqlen_q, num_heads, head_size}); + auto out_varlen = out_work.reshape({batch_size * seqlen_q, num_heads, head_size}); + auto out = std::optional(out_varlen); + std::optional seqused_k = std::nullopt; + std::optional leftpad_k = std::nullopt; + auto block_table = std::optional(block_table_tensor); + + INFINICORE_FLASH_OP(mha_varlen_fwd) + ( + q_varlen, + k_cache, + v_cache, + out, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + block_table, + alibi_slopes, + static_cast(seqlen_q), + static_cast(max_seqlen_k), + 0.0, + p->scale, + false, + seqlen_q > 1, + -1, + -1, + 0.0, + false, + std::nullopt); + + if (out_need_copy_back) { + p->out->copy_from(out_work_ic); + } + return; +#else // Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense. const bool out_need_copy_back = !p->out->is_contiguous(); Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out); @@ -57,19 +143,22 @@ void run(void *planned_meta) { auto q = infinicore::adaptor::to_aten_tensor(p->q); #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); - auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); #elif defined(ENABLE_QY_API) Tensor k_cache_work = p->k_cache->contiguous(); Tensor v_cache_work = p->v_cache->contiguous(); auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work); auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work); #endif - auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); - auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); +#endif auto alibi_slopes = p->alibi_slopes ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; + auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + std::optional k_new = std::nullopt; std::optional v_new = std::nullopt; std::optional rotary_cos = std::nullopt; @@ -121,6 +210,7 @@ void run(void *planned_meta) { if (out_need_copy_back) { p->out->copy_from(out_work); } +#endif // ENABLE_CAMBRICON_API #else throw std::runtime_error("FlashAttention is not enabled in this build"); #endif diff --git a/src/infinicore/ops/multi_head_attention/mha_flashattn.cc b/src/infinicore/ops/multi_head_attention/mha_flashattn.cc index 13c96b94d..689e3db24 100644 --- a/src/infinicore/ops/multi_head_attention/mha_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention/mha_flashattn.cc @@ -8,6 +8,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) #include #endif +#if defined(ENABLE_CAMBRICON_API) +#include +#endif #endif namespace infinicore::op::mha_impl::flashattn { @@ -40,10 +43,10 @@ void *plan(Tensor out, namespace { // Only support nv for now -#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_NVIDIA_API) +#if defined(ENABLE_FLASH_ATTN) && (defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_CAMBRICON_API)) // MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_fwd` at global scope (no namespace), // while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_fwd`. -#if defined(ENABLE_METAX_API) +#if defined(ENABLE_METAX_API) || defined(ENABLE_CAMBRICON_API) #define INFINICORE_FLASH_OP(name) ::name #else #define INFINICORE_FLASH_OP(name) flash::name @@ -54,8 +57,12 @@ namespace { void run(void *planned_meta) { // Only support nv for now -#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_NVIDIA_API) +#if defined(ENABLE_FLASH_ATTN) && (defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_CAMBRICON_API)) +#if defined(ENABLE_CAMBRICON_API) + torch_mlu::mlu::MLUStreamGuard guard(infinicore::adaptor::get_mlu_stream()); +#else c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif auto *p = reinterpret_cast(planned_meta); auto q = infinicore::adaptor::to_aten_tensor(p->q); diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index f80107e7e..7283ca412 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -8,6 +8,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) #include #endif +#if defined(ENABLE_CAMBRICON_API) +#include +#endif #endif namespace infinicore::op::mha_varlen_impl::flashattn { @@ -50,7 +53,7 @@ namespace { #ifdef ENABLE_FLASH_ATTN // MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_varlen_fwd` at global scope (no namespace), // while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_varlen_fwd`. -#if defined(ENABLE_METAX_API) +#if defined(ENABLE_METAX_API) || defined(ENABLE_CAMBRICON_API) #define INFINICORE_FLASH_OP(name) ::name #else #define INFINICORE_FLASH_OP(name) flash::name @@ -61,7 +64,11 @@ namespace { void run(void *planned_meta) { #ifdef ENABLE_FLASH_ATTN +#if defined(ENABLE_CAMBRICON_API) + torch_mlu::mlu::MLUStreamGuard guard(infinicore::adaptor::get_mlu_stream()); +#else c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif auto *p = reinterpret_cast(planned_meta); auto q = infinicore::adaptor::to_aten_tensor(p->q); diff --git a/xmake.lua b/xmake.lua index ccae79cd2..4266f8305 100644 --- a/xmake.lua +++ b/xmake.lua @@ -1,5 +1,5 @@ add_rules("mode.debug", "mode.release") -add_requires("boost", {configs = {stacktrace = true}}) +add_requires("boost", {configs = {stacktrace = true, cmake = false}}) add_requires("pybind11") -- Define color codes @@ -468,6 +468,7 @@ target("infinicore_cpp_api") add_includedirs(INFINI_ROOT.."/include", { public = true }) add_linkdirs(INFINI_ROOT.."/lib") + add_rpathdirs(INFINI_ROOT.."/lib") add_links("infiniop", "infinirt", "infiniccl") if get_config("flash-attn") and get_config("flash-attn") ~= "" then @@ -478,6 +479,9 @@ target("infinicore_cpp_api") if has_config("metax-gpu") then add_deps("flash-attn-metax") end + if has_config("cambricon-mlu") then + add_deps("flash-attn-cambricon") + end if has_config("qy-gpu") then add_deps("flash-attn-qy") end @@ -517,24 +521,44 @@ target("infinicore_cpp_api") local TORCH_DIR = outdata target:add( - "includedirs", - path.join(TORCH_DIR, "include"), + "includedirs", + path.join(TORCH_DIR, "include"), path.join(TORCH_DIR, "include/torch/csrc/api/include"), { public = true }) - + target:add( "linkdirs", path.join(TORCH_DIR, "lib"), { public = true } ) - target:add( - "links", - "torch", - "c10", - "torch_cuda", - "c10_cuda", - { public = true } - ) + if has_config("cambricon-mlu") then + local TORCH_MLU_DIR = os.iorunv("python", {"-c", "import torch_mlu, os; print(os.path.dirname(torch_mlu.__file__))"}):trim() + target:add( + "includedirs", + path.join(TORCH_MLU_DIR, "csrc/include"), + path.join(TORCH_MLU_DIR, "csrc/include/api/include"), + { public = true } + ) + target:add("linkdirs", path.join(TORCH_MLU_DIR, "csrc/lib"), { public = true }) + target:add( + "links", + "torch", + "c10", + "torch_cpu", + "torch_mlu", + "torch_mlu_bangc", + { public = true } + ) + else + target:add( + "links", + "torch", + "c10", + "torch_cuda", + "c10_cuda", + { public = true } + ) + end end end) @@ -584,6 +608,7 @@ target("_infinicore") add_includedirs(INFINI_ROOT.."/include", { public = true }) add_linkdirs(INFINI_ROOT.."/lib") + add_rpathdirs(INFINI_ROOT.."/lib") add_links("infiniop", "infinirt", "infiniccl") add_files("src/infinicore/pybind11/**.cc") diff --git a/xmake/bang.lua b/xmake/bang.lua index ffa85ef6d..fddcc5541 100644 --- a/xmake/bang.lua +++ b/xmake/bang.lua @@ -1,5 +1,6 @@ local NEUWARE_HOME = os.getenv("NEUWARE_HOME") or "/usr/local/neuware" +local FLASH_ATTN_ROOT = get_config("flash-attn") add_includedirs(path.join(NEUWARE_HOME, "include"), {public = true}) add_linkdirs(path.join(NEUWARE_HOME, "lib64")) add_linkdirs(path.join(NEUWARE_HOME, "lib")) @@ -8,6 +9,39 @@ add_links("libcnnl.so") add_links("libcnnl_extra.so") add_links("libcnpapi.so") +local FLASH_ATTN_CAMBRICON_BANG_SO_CONTAINER_DEFAULT = + "/torch/venv3/pytorch/lib/python3.10/site-packages/flash_attn_2_bang.cpython-310-x86_64-linux-gnu.so" + +local function cambricon_flash_attn_bang_so_path() + local env_path = os.getenv("FLASH_ATTN_2_BANG_SO") + if env_path and env_path ~= "" then + env_path = env_path:trim() + if os.isfile(env_path) then + return env_path + end + print(string.format("warning: cambricon+flash-attn: FLASH_ATTN_2_BANG_SO is not a file: %s, fallback to python/container/default path", env_path)) + end + + local container_path = os.getenv("FLASH_ATTN_CAMBRICON_BANG_SO_CONTAINER") + if container_path and container_path ~= "" then + container_path = container_path:trim() + if os.isfile(container_path) then + return container_path + end + print(string.format("warning: cambricon+flash-attn: FLASH_ATTN_CAMBRICON_BANG_SO_CONTAINER is not a file: %s, fallback to python/default path", container_path)) + end + + if not os.isfile(FLASH_ATTN_CAMBRICON_BANG_SO_CONTAINER_DEFAULT) then + print( + string.format( + "warning: cambricon+flash-attn: expected %s; install flash-attn in the container, or export FLASH_ATTN_2_BANG_SO.", + FLASH_ATTN_CAMBRICON_BANG_SO_CONTAINER_DEFAULT + ) + ) + end + return FLASH_ATTN_CAMBRICON_BANG_SO_CONTAINER_DEFAULT +end + rule("mlu") set_extensions(".mlu") @@ -52,6 +86,50 @@ target("infiniop-cambricon") end target_end() +target("flash-attn-cambricon") + set_kind("phony") + set_default(false) + + if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then + before_build(function (target) + local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() + local TORCH_MLU_DIR = os.iorunv("python", {"-c", "import torch_mlu, os; print(os.path.dirname(torch_mlu.__file__))"}):trim() + local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() + local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() + + target:add( + "includedirs", + TORCH_DIR .. "/include", + TORCH_DIR .. "/include/torch/csrc/api/include", + TORCH_MLU_DIR .. "/csrc/include", + TORCH_MLU_DIR .. "/csrc/include/api/include", + PYTHON_INCLUDE, + {public = false} + ) + target:add("linkdirs", TORCH_DIR .. "/lib", TORCH_MLU_DIR .. "/csrc/lib", PYTHON_LIB_DIR, {public = false}) + end) + else + before_build(function (target) + print("Flash Attention not available, skipping flash-attn-cambricon integration") + end) + end +target_end() + +if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then + target("infinicore_cpp_api") + before_link(function (target) + local flash_so_cambricon = cambricon_flash_attn_bang_so_path() + local flash_dir_cambricon = path.directory(flash_so_cambricon) + local flash_name_cambricon = path.filename(flash_so_cambricon) + target:add( + "shflags", + "-Wl,--no-as-needed -L" .. flash_dir_cambricon .. " -l:" .. flash_name_cambricon .. " -Wl,-rpath," .. flash_dir_cambricon, + {force = true} + ) + end) + target_end() +end + target("infinirt-cambricon") set_kind("static") add_deps("infini-utils")