From 1e637102a7bce752fb2ce644ba4deefa39a172ca Mon Sep 17 00:00:00 2001 From: wooway777 Date: Tue, 27 Jan 2026 10:32:36 +0800 Subject: [PATCH 01/25] issue/987 - add .cpp files to ninetoothed includes --- xmake/nvidia.lua | 2 +- xmake/qy.lua | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 75086b8a1..a1133b15b 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -73,7 +73,7 @@ target("infiniop-nvidia") add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c") + add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp") end target_end() diff --git a/xmake/qy.lua b/xmake/qy.lua index ecef359a8..4a512e203 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -101,7 +101,7 @@ target("infiniop-qy") add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c") + add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp") end target_end() From 822a53417bd6e838d1e8f5ca73b83ccf5bd71314 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 23 Jan 2026 12:52:09 +0000 Subject: [PATCH 02/25] issue/978 - metax cuda graph impl and wrappings --- src/infiniop/devices/metax/metax_ht2mc.h | 13 +++++++++ src/infinirt/metax/infinirt_metax.cc | 37 ++++++++++++++++++++---- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index 2db1087d4..a1c8c1ffe 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -85,4 +85,17 @@ #define hcclSuccess mcclSuccess #define hcclCommDestroy mcclCommDestroy #define hcclAllReduce mcclAllReduce +#define hcStreamCaptureMode mcStreamCaptureMode +#define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal +#define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal +#define hcStreamCaptureModeRelaxed mcStreamCaptureModeRelaxed +#define hcStreamBeginCapture mcStreamBeginCapture +#define hcStreamEndCapture mcStreamEndCapture +#define hcGraph_t mcGraph_t +#define hcGraphExec_t mcGraphExec_t +#define hcGraphNode_t mcGraphNode_t +#define hcGraphInstantiate mcGraphInstantiate +#define hcGraphDestroy mcGraphDestroy +#define hcGraphExecDestroy mcGraphExecDestroy +#define hcGraphLaunch mcGraphLaunch #endif diff --git a/src/infinirt/metax/infinirt_metax.cc b/src/infinirt/metax/infinirt_metax.cc index aca187366..a9d69a9d8 100644 --- a/src/infinirt/metax/infinirt_metax.cc +++ b/src/infinirt/metax/infinirt_metax.cc @@ -154,15 +154,32 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { } infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + hcStreamCaptureMode graph_mode; + if (mode == INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) { + graph_mode = hcStreamCaptureModeGlobal; + } else if (mode == INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL) { + graph_mode = hcStreamCaptureModeThreadLocal; + } else if (mode == INFINIRT_STREAM_CAPTURE_MODE_RELAXED) { + graph_mode = hcStreamCaptureModeRelaxed; + } else { + return INFINI_STATUS_BAD_PARAM; + } + + CHECK_MACART(hcStreamBeginCapture((hcStream_t)stream, graph_mode)); + + return INFINI_STATUS_SUCCESS; } infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + 
hcGraph_t graph; + CHECK_MACART(hcStreamEndCapture((hcStream_t)stream, &graph)); + *graph_ptr = graph; + return INFINI_STATUS_SUCCESS; } infiniStatus_t graphDestroy(infinirtGraph_t graph) { - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + CHECK_MACART(hcGraphDestroy((hcGraph_t)graph)); + return INFINI_STATUS_SUCCESS; } infiniStatus_t graphInstantiate( @@ -171,15 +188,23 @@ infiniStatus_t graphInstantiate( infinirtGraphNode_t *node_ptr, char *log_buffer, size_t buffer_size) { - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + CHECK_MACART(hcGraphInstantiate( + (hcGraphExec_t *)graph_exec_ptr, + (hcGraph_t)graph, + (hcGraphNode_t *)node_ptr, + log_buffer, + buffer_size)); + return INFINI_STATUS_SUCCESS; } infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + CHECK_MACART(hcGraphExecDestroy((hcGraphExec_t)graph_exec)); + return INFINI_STATUS_SUCCESS; } infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + CHECK_MACART(hcGraphLaunch((hcGraphExec_t)graph_exec, (hcStream_t)stream)); + return INFINI_STATUS_SUCCESS; } } // namespace infinirt::metax From cc2cc3a1cb04c176f505604e5b5dd12ca2854185 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 26 Dec 2025 06:32:52 +0000 Subject: [PATCH 03/25] issue/846 - Refactor embedding to support device-side input and CUDA graph recording - Ensure embedding tensors are on the same device. Change format. - Optimize embedding kernel with vectorized memory access and __ldg - Add vectorized memory access using float4/float2, half2, and bfloat162 - Use __ldg instruction for read-only weight and indices access - Add memory alignment checks to enable vectorized paths - Add __restrict__ keywords for better compiler optimization - Implement dynamic block size selection based on embedding_dim --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/embedding.hpp | 7 + include/infiniop.h | 1 + include/infiniop/ops/embedding.h | 26 ++ python/infinicore/nn/functional/embedding.py | 5 +- src/infinicore/nn/embedding.cc | 82 +---- src/infinicore/ops/embedding/embedding.cc | 84 ++--- .../ops/embedding/embedding_infiniop.cc | 49 +++ .../ops/embedding/cpu/embedding_cpu.cc | 109 ++++++ .../ops/embedding/cpu/embedding_cpu.h | 8 + src/infiniop/ops/embedding/embedding.h | 54 +++ .../ops/embedding/nvidia/embedding_kernel.cuh | 178 ++++++++++ .../ops/embedding/nvidia/embedding_nvidia.cu | 169 ++++++++++ .../ops/embedding/nvidia/embedding_nvidia.cuh | 8 + src/infiniop/ops/embedding/operator.cc | 118 +++++++ .../EMBEDDING_GRAPH_RECORDING_COMPARISON.md | 159 +++++++++ .../nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md | 317 ++++++++++++++++++ test/infinicore/nn/embedding.py | 11 +- .../nn/test_embedding_graph_recording.py | 284 ++++++++++++++++ test/infinicore/ops/embedding.py | 20 +- 20 files changed, 1528 insertions(+), 162 deletions(-) create mode 100644 include/infiniop/ops/embedding.h create mode 100644 src/infinicore/ops/embedding/embedding_infiniop.cc create mode 100644 src/infiniop/ops/embedding/cpu/embedding_cpu.cc create mode 100644 src/infiniop/ops/embedding/cpu/embedding_cpu.h create mode 100644 src/infiniop/ops/embedding/embedding.h create mode 100644 src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh create mode 100644 src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu create mode 100644 src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh create mode 100644 src/infiniop/ops/embedding/operator.cc create mode 100644 
test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md create mode 100644 test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md create mode 100644 test/infinicore/nn/test_embedding_graph_recording.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index a7249ec9d..88dd5b342 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -4,6 +4,7 @@ #include "ops/add_rms_norm.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" +#include "ops/embedding.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp index 4fd9991c4..6be997134 100644 --- a/include/infinicore/ops/embedding.hpp +++ b/include/infinicore/ops/embedding.hpp @@ -4,6 +4,13 @@ namespace infinicore::op { +class Embedding { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor out, Tensor input, Tensor weight); + static common::OpDispatcher &dispatcher(); +}; + Tensor embedding(Tensor input, Tensor weight); void embedding_(Tensor out, Tensor input, Tensor weight); } // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index c0a09fcb4..378a79a43 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -9,6 +9,7 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" +#include "infiniop/ops/embedding.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/layer_norm.h" diff --git a/include/infiniop/ops/embedding.h b/include/infiniop/ops/embedding.h new file mode 100644 index 000000000..e5ffc211d --- /dev/null +++ b/include/infiniop/ops/embedding.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_EMBEDDING_API_H__ +#define __INFINIOP_EMBEDDING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t; + +__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor( + infiniopHandle_t handle, + infiniopEmbeddingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc); + +__C __export infiniStatus_t infiniopEmbedding( + infiniopEmbeddingDescriptor_t desc, + void *output, + const void *input, + const void *weight, + void *stream); + +__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor( + infiniopEmbeddingDescriptor_t desc); + +#endif + diff --git a/python/infinicore/nn/functional/embedding.py b/python/infinicore/nn/functional/embedding.py index f346d380a..592a12290 100644 --- a/python/infinicore/nn/functional/embedding.py +++ b/python/infinicore/nn/functional/embedding.py @@ -22,9 +22,8 @@ def embedding( and (sparse is False) ), "Unsupported parameters." - assert "cpu" == input.device.type, ( - "The device of 'input' variable must be on the CPU." 
- ) + # Note: embedding now supports device-side input for graph recording + # The C++ implementation handles both CPU and device-side inputs if out is None: return Tensor(_infinicore.embedding(input._underlying, weight._underlying)) diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index 85645bf95..6aa86a4fa 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -43,80 +43,20 @@ Embedding::Embedding(size_t num_embeddings, } Tensor Embedding::forward(const Tensor &indices) const { - // Get the shape of indices - auto indices_shape = indices->shape(); - - // Output shape: indices_shape + [embedding_dim] - std::vector output_shape = indices_shape; - output_shape.push_back(embedding_dim_); - - // Create output tensor on the same device as weight - auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device()); - - // Flatten indices for sequential row copies - auto cpu_device = Device(Device::Type::CPU, 0); - auto indices_cpu = indices->to(cpu_device)->contiguous(); - - // Calculate total number of lookups - size_t num_lookups = 1; - for (auto dim : indices_shape) { - num_lookups *= dim; + // Ensure indices are on the same device as weight + // This avoids synchronous memcpy in ops layer which would hurt performance + Tensor indices_on_device = indices; + if (indices->device() != device_) { + indices_on_device = indices->to(device_); } - const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype()); - - // Source and destination base pointers - auto *weight_base = weight_->data(); - auto *out_base = out->data(); - - // Helper lambda to read index based on dtype with bounds checking - auto read_index = [&](size_t i) -> int64_t { - auto dtype = indices_cpu->dtype(); - if (dtype == DataType::I32) { - const auto *data = reinterpret_cast(indices_cpu->data()); - return static_cast(data[i]); - } else if (dtype == DataType::I64) { - const auto *data = reinterpret_cast(indices_cpu->data()); - return data[i]; - } else if (dtype == DataType::U32) { - const auto *data = reinterpret_cast(indices_cpu->data()); - return static_cast(data[i]); - } else if (dtype == DataType::U64) { - const auto *data = reinterpret_cast(indices_cpu->data()); - uint64_t val = data[i]; - // Check if value can fit in int64_t - if (val > static_cast(std::numeric_limits::max())) { - throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val)); - } - return static_cast(val); - } else { - throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast(dtype))); - } - }; - - if (weight_->device().getType() == Device::Type::CPU) { - // CPU path: memcpy row by row - for (size_t i = 0; i < num_lookups; ++i) { - int64_t idx = read_index(i); - if (idx < 0 || idx >= static_cast(num_embeddings_)) { - throw std::out_of_range( - "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); - } - std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); - } - } else { - // Device path: use stream-ordered D2D copies - for (size_t i = 0; i < num_lookups; ++i) { - int64_t idx = read_index(i); - if (idx < 0 || idx >= static_cast(num_embeddings_)) { - throw std::out_of_range( - "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); - } - context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); - } - } + // Ensure indices are contiguous for efficient access 
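+    // Conceptually the lookup below is a per-row gather over the weight matrix: for every
+    // flattened index i, out[i, :] = weight[indices[i], :]. A flat, contiguous, device-resident
+    // integer buffer is therefore all the device kernel needs, with no host round-trip.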
+ // op::embedding now supports device-side input for graph recording + Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous(); - return out; + // Use op::embedding which now supports device-side input and batch dimension + // This enables full graph recording support without synchronization + return op::embedding(indices_contiguous, weight_); } std::string Embedding::extra_repr() const { diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index f1add0c97..96f19803c 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -1,15 +1,34 @@ #include "infinicore/ops/embedding.hpp" +#include "../../utils.hpp" #include "infinicore/context/context.hpp" #include +#include namespace infinicore::op { +common::OpDispatcher &Embedding::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void Embedding::execute(Tensor out, Tensor input, Tensor weight) { + // Check that all tensors are on the same device + // This is critical: if input is on CPU while out/weight are on GPU, + // passing CPU pointer to CUDA kernel will cause memory access errors + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight); + + // Set device context + infinicore::context::setDevice(out->device()); + + // Use dispatcher to lookup kernel (infiniop implementation) + dispatcher().lookup(out->device().getType())(out, input, weight); +} + Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1 ) { auto input_shape = input->shape(); auto weight_shape = weight->shape(); - // auto vocab_size = weight_shape[0]; auto embedding_dim = weight_shape[1]; // Assign memory to out variables @@ -22,68 +41,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i } void embedding_(Tensor out, Tensor input, Tensor weight) { - assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); - assert(infinicore::Device::Type::CPU == input->device().getType()); - - auto input_shape = input->shape(); - auto weight_shape = weight->shape(); - auto embedding_dim = weight_shape[1]; - - // Calculate the number of token - Size counts = 1; - for (auto &v : input_shape) { - counts *= v; - } - - // the bytes of one token - const Size bytes = dsize(weight->dtype()) * embedding_dim; - auto *weight_ptr = weight->data(); - auto *out_ptr = out->data(); - - // copies - if (weight->device().getType() == Device::Type::CPU) { - if (infinicore::DataType::I64 == input->dtype()) { - const int64_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - std::memcpy(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } else if (infinicore::DataType::I32 == input->dtype()) { - const int32_t *input_arr = reinterpret_cast(input->data()); - - for (Size i = 0; i < counts; ++i) { - int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - std::memcpy(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } - - } else { - if (infinicore::DataType::I64 == input->dtype()) { - const int64_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int64_t idx = input_arr[i]; - assert((idx 
>= 0) && (idx < weight_shape[0])); - context::memcpyD2D(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } else if (infinicore::DataType::I32 == input->dtype()) { - const int32_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - context::memcpyD2D(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } - } + Embedding::execute(out, input, weight); } } // namespace infinicore::op diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc new file mode 100644 index 000000000..dfbbb2f71 --- /dev/null +++ b/src/infinicore/ops/embedding/embedding_infiniop.cc @@ -0,0 +1,49 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/embedding.hpp" +#include + +namespace infinicore::op::embedding_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopEmbeddingDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, Tensor input, Tensor weight) { + size_t seed = hash_combine(out, input, weight); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopEmbeddingDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor( + context::getInfiniopHandle(device), &desc, + out->desc(), input->desc(), weight->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + INFINICORE_CHECK_ERROR(infiniopEmbedding( + desc, + out->data(), + input->data(), + weight->data(), + context::getStream())); +} + +static bool registered = []() { + Embedding::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::embedding_impl::infiniop diff --git a/src/infiniop/ops/embedding/cpu/embedding_cpu.cc b/src/infiniop/ops/embedding/cpu/embedding_cpu.cc new file mode 100644 index 000000000..8e6648063 --- /dev/null +++ b/src/infiniop/ops/embedding/cpu/embedding_cpu.cc @@ -0,0 +1,109 @@ +#include "embedding_cpu.h" +#include "../../../../utils.h" +#include "../../../handle.h" +#include "../../../tensor.h" +#include + +namespace op::embedding::cpu { + +struct Descriptor::Opaque {}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + 
CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + size_t vocab_size = weight_shape[0]; + + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + size_t element_size = infiniSizeOf(_weight_dtype); + size_t row_bytes = _embedding_dim * element_size; + + if (_input_dtype == INFINI_DTYPE_I32) { + const int32_t *indices_ptr = reinterpret_cast(input); + const std::byte *weight_ptr = reinterpret_cast(weight); + std::byte *out_ptr = reinterpret_cast(output); + + for (size_t i = 0; i < _num_indices; ++i) { + int32_t idx = indices_ptr[i]; + if (idx >= 0 && static_cast(idx) < _vocab_size) { + std::memcpy(out_ptr + i * row_bytes, + weight_ptr + static_cast(idx) * row_bytes, + row_bytes); + } + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + const int64_t *indices_ptr = reinterpret_cast(input); + const std::byte *weight_ptr = reinterpret_cast(weight); + std::byte *out_ptr = reinterpret_cast(output); + + for (size_t i = 0; i < _num_indices; ++i) { + int64_t idx = indices_ptr[i]; + if (idx >= 0 && static_cast(idx) < _vocab_size) { + std::memcpy(out_ptr + i * row_bytes, + weight_ptr + static_cast(idx) * row_bytes, + row_bytes); + } + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::cpu diff --git a/src/infiniop/ops/embedding/cpu/embedding_cpu.h b/src/infiniop/ops/embedding/cpu/embedding_cpu.h new file mode 100644 index 000000000..a5cc5b2d0 --- /dev/null +++ b/src/infiniop/ops/embedding/cpu/embedding_cpu.h @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_CPU_H__ +#define __EMBEDDING_CPU_H__ + +#include "../embedding.h" + +DESCRIPTOR(cpu) + +#endif // __EMBEDDING_CPU_H__ diff --git a/src/infiniop/ops/embedding/embedding.h b/src/infiniop/ops/embedding/embedding.h new file mode 100644 index 000000000..e0135dbfe --- /dev/null +++ b/src/infiniop/ops/embedding/embedding.h @@ -0,0 +1,54 @@ +#ifndef __EMBEDDING_H__ +#define __EMBEDDING_H__ + +#include "../../../utils.h" +#include "../../operator.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::embedding::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + size_t _num_indices; \ + size_t _embedding_dim; \ + size_t _vocab_size; \ + infiniDtype_t _input_dtype; \ + infiniDtype_t _weight_dtype; \ + \ + Descriptor( \ + size_t num_indices, \ + size_t embedding_dim, \ + size_t vocab_size, \ + infiniDtype_t input_dtype, \ + infiniDtype_t weight_dtype, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _num_indices(num_indices), \ + _embedding_dim(embedding_dim), \ + _vocab_size(vocab_size), \ + _input_dtype(input_dtype), \ + _weight_dtype(weight_dtype) {} \ + \ + public: \ + ~Descriptor(); \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + 
infiniopTensorDescriptor_t input_desc, \ + infiniopTensorDescriptor_t weight_desc); \ + \ + infiniStatus_t calculate( \ + void *output, \ + const void *input, \ + const void *weight, \ + void *stream) const; \ + }; \ + } + +#endif // __EMBEDDING_H__ diff --git a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh new file mode 100644 index 000000000..0e85b5f6a --- /dev/null +++ b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh @@ -0,0 +1,178 @@ +#ifndef __EMBEDDING_CUDA_KERNEL_CUH__ +#define __EMBEDDING_CUDA_KERNEL_CUH__ + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include +#include +#include + +namespace op::embedding::nvidia { + +// Helper function to check memory alignment +__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) { + // Use size_t for pointer arithmetic in device code (more compatible) + return (reinterpret_cast(ptr) % alignment == 0); +} + +// Vectorized copy for float type using float4 +template +__forceinline__ __device__ void copyVectorizedFloat4( + float *__restrict__ dst, + const float *__restrict__ src, + size_t embedding_dim) { + // Use float4 for vectorized access (16 bytes, 4 floats) + const float4 *src_vec = reinterpret_cast(src); + float4 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 4; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining elements + size_t remaining = embedding_dim % 4; + if (remaining > 0) { + size_t offset = vec_count * 4; + for (size_t i = 0; i < remaining; ++i) { + dst[offset + i] = __ldg(&src[offset + i]); + } + } +} + +// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes) +template +__forceinline__ __device__ void copyVectorizedFloat2( + float *__restrict__ dst, + const float *__restrict__ src, + size_t embedding_dim) { + // Use float2 for vectorized access (8 bytes, 2 floats) + const float2 *src_vec = reinterpret_cast(src); + float2 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]); + } +} + +// Vectorized copy for half type using half2 +template +__forceinline__ __device__ void copyVectorizedHalf2( + half *__restrict__ dst, + const half *__restrict__ src, + size_t embedding_dim) { + // Use half2 for vectorized access (4 bytes, 2 halfs) + const half2 *src_vec = reinterpret_cast(src); + half2 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]); + } +} + +// Vectorized copy for bfloat16 type using bfloat162 +template +__forceinline__ __device__ void copyVectorizedBFloat162( + cuda_bfloat16 *__restrict__ dst, + const cuda_bfloat16 *__restrict__ src, + size_t embedding_dim) { + // Use bfloat162 for vectorized access (4 bytes, 2 bfloat16s) + const cuda_bfloat162 *src_vec = reinterpret_cast(src); + cuda_bfloat162 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy 
using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]); + } +} + +// Scalar copy fallback with __ldg optimization +template +__forceinline__ __device__ void copyScalar( + T *__restrict__ dst, + const T *__restrict__ src, + size_t embedding_dim) { + // Scalar copy with __ldg for read-only weight + for (size_t i = 0; i < embedding_dim; ++i) { + dst[i] = __ldg(&src[i]); + } +} + +template +INFINIOP_CUDA_KERNEL embeddingKernel( + T *__restrict__ output, + const IndexType *__restrict__ indices, + const T *__restrict__ weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + // Calculate global thread index + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < num_indices) { + // Get the index value + IndexType index_val = __ldg(&indices[idx]); + + // Bounds check - handle negative indices gracefully + if (index_val >= 0 && static_cast(index_val) < vocab_size) { + // Copy embedding vector from weight to output + const T *src = weight + static_cast(index_val) * embedding_dim; + T *dst = output + idx * embedding_dim; + + // Choose optimal copy strategy based on type and alignment + if constexpr (std::is_same_v) { + // Check alignment for float4 (16 bytes) + bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16); + if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) { + copyVectorizedFloat4(dst, src, embedding_dim); + } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + // Try float2 if not aligned to 16 bytes + copyVectorizedFloat2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use half2 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedHalf2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use bfloat162 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedBFloat162(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else { + // Fallback to scalar copy with __ldg + copyScalar(dst, src, embedding_dim); + } + } + } +} + +} // namespace op::embedding::nvidia + +#endif // __EMBEDDING_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu new file mode 100644 index 000000000..b714b0aa4 --- /dev/null +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu @@ -0,0 +1,169 @@ +#include "../../../../utils.h" +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../../../tensor.h" +#include "embedding_kernel.cuh" +#include "embedding_nvidia.cuh" +#include + +namespace op::embedding::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + // Validate shapes + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + 
CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + // Check output shape matches input shape + embedding_dim + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + // Validate dtypes + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + // Calculate number of indices (supporting batch dimension) + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + size_t vocab_size = weight_shape[0]; + + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + auto cuda_stream = reinterpret_cast(stream); + + // Dynamic block size optimization based on embedding_dim + // Smaller embedding_dim benefits from larger block size (better occupancy) + // Larger embedding_dim benefits from smaller block size (more registers per thread) + size_t block_size = 256; // Default + if (_embedding_dim <= 64) { + block_size = 512; // Small embedding_dim: use larger block for better occupancy + } else if (_embedding_dim >= 1024) { + block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure + } + + size_t grid_size = (_num_indices + block_size - 1) / block_size; + + // Launch kernel based on dtypes + if (_input_dtype == INFINI_DTYPE_I32) { + const int32_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + const int64_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + embeddingKernel<<>>( + 
reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // Check for kernel launch errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + return INFINI_STATUS_INTERNAL_ERROR; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::nvidia diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh new file mode 100644 index 000000000..c6b966d8d --- /dev/null +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_CUDA_H__ +#define __EMBEDDING_CUDA_H__ + +#include "../embedding.h" + +DESCRIPTOR(nvidia) + +#endif // __EMBEDDING_CUDA_H__ diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc new file mode 100644 index 000000000..50f2f05ed --- /dev/null +++ b/src/infiniop/ops/embedding/operator.cc @@ -0,0 +1,118 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/embedding.h" + +#ifdef ENABLE_CPU_API +#include "cpu/embedding_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#include "nvidia/embedding_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateEmbeddingDescriptor( + infiniopHandle_t handle, + infiniopEmbeddingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::embedding::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + input_desc, \ + weight_desc) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#if defined(ENABLE_ILUVATAR_API) + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#if defined(ENABLE_QY_API) + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#if defined(ENABLE_HYGON_API) + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopEmbedding( + infiniopEmbeddingDescriptor_t desc, + void *output, + const void *input, + const void *weight, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(output, input, weight, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#if defined(ENABLE_ILUVATAR_API) + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#if defined(ENABLE_QY_API) + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#if defined(ENABLE_HYGON_API) + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return 
INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#if defined(ENABLE_ILUVATAR_API) + DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#if defined(ENABLE_QY_API) + DELETE(INFINI_DEVICE_QY, nvidia); +#endif +#if defined(ENABLE_HYGON_API) + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif + } + +#undef DELETE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} diff --git a/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md b/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md new file mode 100644 index 000000000..686c10a1b --- /dev/null +++ b/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md @@ -0,0 +1,159 @@ +# Embedding 图录制支持对比 + +## 改动前后对比 + +### ❌ 改动前:不支持图录制 + +**关键问题代码**(在 `nn::Embedding::forward` 中): +```cpp +// 改动前的实现 +Tensor Embedding::forward(const Tensor &indices) const { + auto cpu_device = Device(Device::Type::CPU, 0); + auto indices_cpu = indices->to(cpu_device)->contiguous(); // ❌ 同步操作! + + // ... 后续处理 +} +``` + +**问题分析**: +1. `indices->to(cpu_device)` 会触发 **同步的 D2H(Device-to-Host)内存拷贝** +2. CUDA Graph 录制要求所有操作都是**异步的**,不能有同步点 +3. 同步操作会导致图录制失败或产生错误 + +**验证方法**: +```python +# 改动前:这个操作会失败或产生同步 +input_ids_device = infinicore.from_list(..., device="cuda:0") # 设备端输入 +output = embedding.forward(input_ids_device) # ❌ 内部会同步拷贝到 CPU +``` + +--- + +### ✅ 改动后:支持图录制 + +**关键改进代码**: +```cpp +// 改动后的实现 +Tensor Embedding::forward(const Tensor &indices) const { + Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous(); + return op::embedding(indices_contiguous, weight_); // ✅ 直接使用设备端 kernel +} +``` + +**改进点**: +1. **移除了同步操作**:不再调用 `indices->to(cpu_device)` +2. **使用设备端 CUDA kernel**:通过 InfiniOP 调用 `embeddingKernel`,完全在设备端执行 +3. **完全异步**:所有操作都在 CUDA stream 上异步执行 + +**实现位置**: +- CUDA Kernel: `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` +- Kernel 启动:使用 `cudaStream_t`,完全异步 +- 无同步点:没有 `cudaDeviceSynchronize()` 或 D2H 拷贝 + +**验证方法**: +```python +# 改动后:这个操作完全异步,支持图录制 +input_ids_device = infinicore.from_list(..., device="cuda:0") # 设备端输入 +output = embedding.forward(input_ids_device) # ✅ 直接使用设备端 kernel,无同步 +``` + +--- + +## 验证方法 + +### 方法 1: 代码检查 + +**检查点**: +1. ✅ 是否有 `->to(cpu_device)` 调用? +2. ✅ 是否有 `synchronize()` 调用? +3. ✅ 是否有设备端 kernel 实现? 
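+
+例如,可以用下面的命令快速完成上述检查(示例命令,假设在仓库根目录下执行):
+
+```bash
+# 有输出 → 仍存在同步的 D2H 拷贝(改动前);无输出 → 同步拷贝已移除(改动后)
+grep -n "to(cpu_device)" src/infinicore/nn/embedding.cc
+```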
+ +**改动前**: +```cpp +// ❌ 有同步操作 +auto indices_cpu = indices->to(cpu_device)->contiguous(); +``` + +**改动后**: +```cpp +// ✅ 无同步操作,直接使用设备端 kernel +return op::embedding(indices_contiguous, weight_); +``` + +### 方法 2: CUDA Graph API 测试 + +运行测试脚本: +```bash +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +**预期结果**: +- ✅ 改动后:图录制成功 +- ❌ 改动前:图录制失败(因为同步操作) + +### 方法 3: 设备端输入测试 + +**关键测试**: +```python +# 创建设备端输入 +input_ids = infinicore.from_list([[1, 2, 3]], dtype=int64, device="cuda:0") + +# 执行 forward +output = embedding.forward(input_ids) # 改动前会失败或同步,改动后成功 +``` + +**改动前**: +- 需要先将 `input_ids` 拷贝到 CPU +- 触发同步操作,无法图录制 + +**改动后**: +- 直接使用设备端 `input_ids` +- 完全异步,支持图录制 + +--- + +## 技术细节对比 + +| 特性 | 改动前 | 改动后 | +|------|--------|--------| +| **输入设备** | 必须在 CPU | 支持设备端 | +| **同步操作** | ❌ 有(D2H拷贝) | ✅ 无 | +| **Kernel位置** | CPU 实现 | CUDA kernel | +| **图录制支持** | ❌ 不支持 | ✅ 支持 | +| **Batch维度** | ✅ 支持 | ✅ 支持 | +| **性能** | 较慢(同步开销) | 更快(异步) | + +--- + +## 关键代码位置 + +### 改动前的问题代码 +- `src/infinicore/nn/embedding.cc` (旧版本) + - 第58行:`indices->to(cpu_device)->contiguous()` ❌ + +### 改动后的实现 +- `src/infinicore/nn/embedding.cc` (新版本) + - 第48行:`indices->is_contiguous() ? indices : indices->contiguous()` ✅ + - 第52行:`return op::embedding(indices_contiguous, weight_)` ✅ + +- `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` + - CUDA kernel 实现,完全异步 ✅ + +- `src/infinicore/ops/embedding/embedding_infiniop.cc` + - InfiniOP 包装,调用设备端 kernel ✅ + +--- + +## 总结 + +**改动前的关键问题**: +- ❌ `indices->to(cpu_device)` 触发同步 D2H 拷贝 +- ❌ 无法进行 CUDA Graph 录制 +- ❌ 性能较差(同步开销) + +**改动后的改进**: +- ✅ 移除所有同步操作 +- ✅ 使用设备端 CUDA kernel +- ✅ 完全支持 CUDA Graph 录制 +- ✅ 性能更好(完全异步) + diff --git a/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md b/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md new file mode 100644 index 000000000..e5e60db2b --- /dev/null +++ b/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md @@ -0,0 +1,317 @@ +# Embedding 图录制测试使用指南 + +## 🚀 快速开始 + +### 运行测试 + +```bash +cd /home/zhuyue/codes/InfiniCore +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +--- + +## 📊 改动前后对比 + +### ❌ 改动前:不支持图录制 + +#### 1. 运行测试 + +```bash +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +#### 2. 预期输出 + +``` +============================================================ +Embedding 图录制支持验证 +============================================================ +============================================================ +测试 Embedding 图录制支持 +============================================================ + +1. 输入张量信息: + - Shape: [4, 32] + - Device: cuda + - Dtype: int64 + +2. 尝试 CUDA Graph 录制... + 使用 PyTorch CUDA Graph API 测试... + ✗ 图录制失败: [错误信息] + ✗ Embedding 不支持 CUDA Graph 录制(可能包含同步操作) + +3. 简化验证:检查异步操作支持 + ✓ 输入在设备上 + ⚠ 操作可能包含同步点(事件立即完成) ← 关键:说明有同步操作 + ✓ Forward 执行时间: X.XXX ms + ✓ 输出形状: [4, 32, 128] + ✓ 输出设备: cuda + ✗ 输出验证失败 + +============================================================ +测试 Embedding 设备端输入支持 +============================================================ + +测试 1: 设备端输入 + ✗ 设备端输入失败: [错误信息] + +============================================================ +测试结果总结 +============================================================ +CUDA Graph 录制: ✗ 失败 +设备端输入: ✗ 失败 +============================================================ +✗ 部分测试失败,Embedding 可能不完全支持图录制 +============================================================ +``` + +#### 3. 关键失败点 + +- **图录制失败**:因为代码中有 `indices->to(cpu_device)` 同步操作 +- **设备端输入失败**:需要先将输入拷贝到 CPU +- **异步验证显示同步点**:事件立即完成,说明有同步操作 + +--- + +### ✅ 改动后:支持图录制 + +#### 1. 
运行测试 + +```bash +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +#### 2. 预期输出 + +``` +============================================================ +Embedding 图录制支持验证 +============================================================ +============================================================ +测试 Embedding 图录制支持 +============================================================ + +1. 输入张量信息: + - Shape: [4, 32] + - Device: cuda + - Dtype: int64 + +2. 尝试 CUDA Graph 录制... + 使用 PyTorch CUDA Graph API 测试... + ✓ 成功完成图录制! + ✓ Embedding 支持 CUDA Graph 录制 + ✓ 图可以成功重放 + +============================================================ +测试 Embedding 设备端输入支持 +============================================================ + +测试 1: 设备端输入 + ✓ 设备端输入成功 + - 输入设备: cuda + - 输出设备: cuda + - 输出形状: [1, 5, 64] + +============================================================ +测试结果总结 +============================================================ +CUDA Graph 录制: ✓ 通过 +设备端输入: ✓ 通过 +============================================================ +✓ 所有测试通过!Embedding 支持图录制 +============================================================ +``` + +#### 3. 关键成功点 + +- **图录制成功**:所有操作都是异步的,无同步点 +- **设备端输入成功**:直接支持设备端输入,无需拷贝 +- **图可以重放**:验证图录制的正确性 + +--- + +## 🔍 如何判断当前是改动前还是改动后? + +### 方法 1: 代码检查(最快) + +```bash +# 检查是否有同步操作 +grep -n "to(cpu_device)" src/infinicore/nn/embedding.cc + +# 结果解读: +# - 有输出 → ❌ 改动前(不支持图录制) +# - 无输出 → ✅ 改动后(支持图录制) +``` + +### 方法 2: 检查设备端实现 + +```bash +# 检查是否有设备端 CUDA kernel +ls src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu + +# 结果解读: +# - 不存在 → ❌ 改动前(不支持图录制) +# - 存在 → ✅ 改动后(支持图录制) +``` + +### 方法 3: 运行测试(最准确) + +```bash +python test/infinicore/nn/test_embedding_graph_recording.py + +# 查看 "CUDA Graph 录制" 测试结果: +# - ✓ 通过 → ✅ 改动后(支持图录制) +# - ✗ 失败 → ❌ 改动前(不支持图录制) +``` + +--- + +## 📝 测试内容详解 + +### 测试 1: CUDA Graph 录制 + +**目的**:验证 embedding 是否可以在 CUDA Graph 中录制 + +**工作原理**: +1. 使用 PyTorch 的 `torch.cuda.CUDAGraph()` API +2. 在图录制模式下执行 `embedding.forward()` +3. 如果包含同步操作,录制会失败 +4. 如果完全异步,录制会成功 + +**改动前**: +- ❌ 录制失败:因为 `indices->to(cpu_device)` 触发同步 + +**改动后**: +- ✅ 录制成功:使用设备端 CUDA kernel,完全异步 + +### 测试 2: 设备端输入支持 + +**目的**:验证 embedding 是否支持设备端输入 + +**工作原理**: +1. 创建设备端的 `input_ids` +2. 直接调用 `embedding.forward(input_ids)` +3. 检查是否成功且输出在设备上 + +**改动前**: +- ❌ 可能需要先将输入拷贝到 CPU(同步操作) + +**改动后**: +- ✅ 直接支持设备端输入(完全异步) + +### 测试 3: 异步操作验证(备用) + +**目的**:当 CUDA Graph API 不可用时,使用事件验证异步性 + +**工作原理**: +1. 使用 `DeviceEvent` 记录操作时间 +2. 检查操作是否立即完成(同步)或异步执行 + +**改动前**: +- ⚠️ 事件立即完成,说明有同步操作 + +**改动后**: +- ✅ 事件未立即完成,说明是异步操作 + +--- + +## 🛠️ 故障排查 + +### 问题 1: PyTorch 版本不支持 CUDA Graph + +**现象**: +``` +⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法 +``` + +**解决**: +- 需要 PyTorch 2.0+ 版本 +- 测试会自动降级到简化验证方法 +- 简化验证也能检测是否支持图录制 + +### 问题 2: CUDA 不可用 + +**现象**: +``` +⚠ CUDA 不可用,跳过图录制测试 +``` + +**解决**: +- 确保 CUDA 设备可用 +- 测试需要 CUDA 环境 + +### 问题 3: 测试失败但不确定原因 + +**检查清单**: +1. ✅ 确认代码已编译(特别是 CUDA 支持) +2. ✅ 确认 CUDA 设备可用 +3. ✅ 检查 `src/infinicore/nn/embedding.cc` 是否还有 `to(cpu_device)` +4. ✅ 检查是否有 `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` + +--- + +## 💡 快速验证脚本 + +创建一个简单的验证脚本: + +```bash +#!/bin/bash +# quick_check.sh + +cd /home/zhuyue/codes/InfiniCore + +echo "=== 1. 代码检查 ===" +if grep -q "to(cpu_device)" src/infinicore/nn/embedding.cc; then + echo "❌ 改动前:发现同步操作 to(cpu_device)" +else + echo "✅ 改动后:无同步操作" +fi + +echo "" +echo "=== 2. 设备端实现检查 ===" +if [ -f "src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu" ]; then + echo "✅ 改动后:有设备端 CUDA kernel" +else + echo "❌ 改动前:无设备端 CUDA kernel" +fi + +echo "" +echo "=== 3. 
运行测试 ===" +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +使用方法: +```bash +chmod +x quick_check.sh +./quick_check.sh +``` + +--- + +## 📋 总结 + +### 改动前特征 + +| 检查项 | 结果 | +|--------|------| +| 代码中有 `to(cpu_device)` | ✅ 有 | +| 有设备端 CUDA kernel | ❌ 无 | +| 图录制测试 | ❌ 失败 | +| 设备端输入 | ❌ 失败 | + +### 改动后特征 + +| 检查项 | 结果 | +|--------|------| +| 代码中有 `to(cpu_device)` | ❌ 无 | +| 有设备端 CUDA kernel | ✅ 有 | +| 图录制测试 | ✅ 成功 | +| 设备端输入 | ✅ 成功 | + +### 最简单的判断方法 + +**运行测试脚本**,查看 "CUDA Graph 录制" 测试结果: +- ✅ **通过** → 支持图录制(改动后) +- ❌ **失败** → 不支持图录制(改动前) + diff --git a/test/infinicore/nn/embedding.py b/test/infinicore/nn/embedding.py index 667713537..023bc7762 100644 --- a/test/infinicore/nn/embedding.py +++ b/test/infinicore/nn/embedding.py @@ -114,14 +114,9 @@ def torch_operator(self, x, weight): def infinicore_operator(self, x, weight): """InfiniCore nn.Embedding implementation""" - - if x.device.type != "cpu": - # 将 input的数据 转移到 cpu 上 - x_torch = convert_infinicore_to_torch(x) - x_torch_cpu = x_torch.contiguous().cpu() - - x = infinicore.from_torch(x_torch_cpu) - + # Note: embedding now supports device-side input for graph recording + # No need to convert to CPU anymore - the implementation handles both CPU and device inputs + num_embeddings, embedding_dim = weight.shape model = infinicore.nn.Embedding( diff --git a/test/infinicore/nn/test_embedding_graph_recording.py b/test/infinicore/nn/test_embedding_graph_recording.py new file mode 100644 index 000000000..405f71e0d --- /dev/null +++ b/test/infinicore/nn/test_embedding_graph_recording.py @@ -0,0 +1,284 @@ +""" +测试 embedding 是否支持 CUDA Graph 录制 + +使用方法: + python test/infinicore/nn/test_embedding_graph_recording.py + +关键验证点: +1. 改动前:indices->to(cpu_device) 会触发同步的 D2H 拷贝,导致图录制失败 +2. 改动后:使用设备端 CUDA kernel,完全异步,支持图录制 + +预期结果: +- 改动前:图录制失败,设备端输入可能失败 +- 改动后:图录制成功,设备端输入成功 +""" + +import infinicore +import torch +import ctypes + + +def test_embedding_graph_recording(): + """测试 embedding 是否支持 CUDA Graph 录制""" + print("=" * 60) + print("测试 Embedding 图录制支持") + print("=" * 60) + + # 检查是否有 CUDA + if not torch.cuda.is_available(): + print("⚠ CUDA 不可用,跳过图录制测试") + return False + + device = infinicore.device("cuda", 0) + + # 创建 embedding 模块 + vocab_size = 1000 + embedding_dim = 128 + embedding = infinicore.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + dtype=infinicore.float32, + device=device + ) + + # 创建设备端的 input_ids(这是关键:改动前不支持,改动后支持) + batch_size = 4 + seq_len = 32 + input_ids_device = infinicore.from_list( + [[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)], + dtype=infinicore.int64, + device=device + ) + + print(f"\n1. 输入张量信息:") + print(f" - Shape: {input_ids_device.shape}") + print(f" - Device: {input_ids_device.device.type}") + print(f" - Dtype: {input_ids_device.dtype}") + + # 尝试使用 CUDA Graph 录制 + print(f"\n2. 
尝试 CUDA Graph 录制...") + + # 使用 PyTorch 的 CUDA Graph API 进行测试(更简单可靠) + try: + # 设置设备 + infinicore.set_device(device) + + # 使用 PyTorch 的 CUDA Graph API + # 注意:PyTorch 2.0+ 支持 torch.cuda.graph + try: + # 方法 1: 使用 PyTorch 的 CUDA Graph(推荐) + print(" 使用 PyTorch CUDA Graph API 测试...") + + # 创建 warmup 输入 + warmup_input = input_ids_device + + # Warmup(图录制前需要先执行一次,包括内存分配) + warmup_output = embedding.forward(warmup_input) + infinicore.sync_stream() # 同步确保 warmup 完成 + + # 预先分配输出张量(CUDA Graph 不支持动态内存分配) + # 输出形状: input_shape + [embedding_dim] + output_shape = list(input_ids_device.shape) + [embedding_dim] + output = infinicore.empty( + output_shape, + dtype=embedding.weight.dtype, + device=device + ) + + # Warmup embedding(确保内存分配完成) + import infinicore.nn.functional as F + F.embedding(warmup_input, embedding.weight, out=output) + infinicore.sync_stream() + + # 开始图录制(使用预先分配的 output) + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # 使用 embedding 的 out 参数(in-place),传入预先分配的 output + F.embedding(input_ids_device, embedding.weight, out=output) + + print(" ✓ 成功完成图录制!") + print(" ✓ Embedding 支持 CUDA Graph 录制") + + # 验证图可以重复执行 + graph.replay() + infinicore.sync_stream() + + print(" ✓ 图可以成功重放") + return True + + except AttributeError: + # PyTorch 版本可能不支持 torch.cuda.graph + print(" ⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法") + return test_embedding_async_verification(embedding, input_ids_device) + except RuntimeError as e: + error_msg = str(e) + if "capture" in error_msg.lower() or "graph" in error_msg.lower(): + print(f" ✗ 图录制失败: {e}") + print(" ✗ Embedding 不支持 CUDA Graph 录制(可能包含同步操作)") + return False + else: + print(f" ⚠ 图录制测试异常: {e}") + return test_embedding_async_verification(embedding, input_ids_device) + + except Exception as e: + print(f" ⚠ 图录制测试异常: {e}") + print(" 使用简化验证方法...") + import traceback + traceback.print_exc() + return test_embedding_async_verification(embedding, input_ids_device) + + +def test_embedding_async_verification(embedding, input_ids_device): + """ + 简化验证:检查是否有同步操作 + + 关键检查点: + 1. 输入是否可以在设备上(改动前需要 CPU,改动后支持设备) + 2. 操作是否完全异步(没有同步点) + """ + print("\n3. 
简化验证:检查异步操作支持") + + # 验证 1: 输入可以在设备上 + if input_ids_device.device.type != "cuda": + print(" ✗ 输入不在设备上,无法验证") + return False + + print(" ✓ 输入在设备上") + + # 验证 2: 执行 forward,检查是否有同步操作 + # 如果改动前,这里会调用 indices->to(cpu_device),触发同步 + # 如果改动后,直接使用设备端 kernel,完全异步 + + try: + # 记录开始时间 + start_event = infinicore.DeviceEvent(enable_timing=True) + end_event = infinicore.DeviceEvent(enable_timing=True) + + start_event.record() + output = embedding.forward(input_ids_device) + end_event.record() + + # 不立即同步,检查操作是否异步 + # 如果操作是异步的,query 应该返回 False(未完成) + # 如果操作是同步的,可能已经完成 + + # 等待一小段时间 + import time + time.sleep(0.001) # 1ms + + # 检查事件状态 + is_complete = end_event.query() + + if not is_complete: + print(" ✓ 操作是异步的(事件未立即完成)") + else: + print(" ⚠ 操作可能包含同步点(事件立即完成)") + + # 同步并测量时间 + end_event.synchronize() + elapsed = start_event.elapsed_time(end_event) + + print(f" ✓ Forward 执行时间: {elapsed:.3f} ms") + print(f" ✓ 输出形状: {output.shape}") + print(f" ✓ 输出设备: {output.device.type}") + + # 验证输出正确性 + embedding_dim = embedding.embedding_dim() + expected_shape = (*input_ids_device.shape, embedding_dim) + if output.device.type == "cuda" and output.shape == expected_shape: + print(" ✓ 输出在设备上,形状正确") + return True + else: + print(f" ✗ 输出验证失败") + print(f" 期望形状: {expected_shape}, 实际形状: {output.shape}") + print(f" 期望设备: cuda, 实际设备: {output.device.type}") + return False + + except Exception as e: + print(f" ✗ 验证失败: {e}") + import traceback + traceback.print_exc() + return False + + +def test_embedding_device_input_support(): + """测试 embedding 是否支持设备端输入""" + print("\n" + "=" * 60) + print("测试 Embedding 设备端输入支持") + print("=" * 60) + + if not torch.cuda.is_available(): + print("⚠ CUDA 不可用,跳过测试") + return False + + device = infinicore.device("cuda", 0) + vocab_size = 100 + embedding_dim = 64 + + embedding = infinicore.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + dtype=infinicore.float32, + device=device + ) + + # 测试 1: 设备端输入(改动后支持) + print("\n测试 1: 设备端输入") + try: + input_ids_device = infinicore.from_list( + [[1, 2, 3, 4, 5]], + dtype=infinicore.int64, + device=device + ) + output = embedding.forward(input_ids_device) + print(f" ✓ 设备端输入成功") + print(f" - 输入设备: {input_ids_device.device.type}") + print(f" - 输出设备: {output.device.type}") + print(f" - 输出形状: {output.shape}") + return True + except Exception as e: + print(f" ✗ 设备端输入失败: {e}") + return False + + +def main(): + """主测试函数""" + print("\n" + "=" * 60) + print("Embedding 图录制支持验证") + print("=" * 60) + + results = [] + + # 测试 1: 图录制支持 + result1 = test_embedding_graph_recording() + results.append(("CUDA Graph 录制", result1)) + + # 测试 2: 设备端输入支持 + result2 = test_embedding_device_input_support() + results.append(("设备端输入", result2)) + + # 总结 + print("\n" + "=" * 60) + print("测试结果总结") + print("=" * 60) + + all_passed = True + for test_name, result in results: + status = "✓ 通过" if result else "✗ 失败" + print(f"{test_name}: {status}") + if not result: + all_passed = False + + print("\n" + "=" * 60) + if all_passed: + print("✓ 所有测试通过!Embedding 支持图录制") + else: + print("✗ 部分测试失败,Embedding 可能不完全支持图录制") + print("=" * 60) + + return all_passed + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/test/infinicore/ops/embedding.py b/test/infinicore/ops/embedding.py index a8bdc00b8..6cb7755af 100644 --- a/test/infinicore/ops/embedding.py +++ b/test/infinicore/ops/embedding.py @@ -102,23 +102,9 @@ def torch_operator(self, *args, out=None, **kwargs): def infinicore_operator(self, input, weight, out=None, **kwargs): """InfiniCore 
Embedding implementation""" - - if input.device.type == "cpu": - input_cpu = input - else: - # 将 input的数据 转移到 cpu 上 - torch_reference = torch.zeros( - input.shape, - dtype=to_torch_dtype(input.dtype), - device="cpu" if "cpu" == input.device.type else "cuda", - ) - torch_reference = convert_infinicore_to_torch(input) - torch_reference = torch_reference.contiguous().cpu() - - # 创建cpu的 input - input_cpu = infinicore_tensor_from_torch(torch_reference) - - return infinicore.nn.functional.embedding(input_cpu, weight, out=out) + # Note: embedding now supports device-side input for graph recording + # No need to convert to CPU anymore - the implementation handles both CPU and device inputs + return infinicore.nn.functional.embedding(input, weight, out=out) def main(): From 835209e715af6527aa7cfabd737fd0af13bc60c7 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 8 Jan 2026 15:56:50 +0800 Subject: [PATCH 04/25] issue/900 - support embedding on iluvatar, metax, and moore --- .../{nvidia => cuda}/embedding_kernel.cuh | 62 ----- .../ops/embedding/metax/embedding_metax.cuh | 8 + .../ops/embedding/metax/embedding_metax.maca | 217 +++++++++++++++++ .../ops/embedding/moore/embedding_moore.h | 8 + .../ops/embedding/moore/embedding_moore.mu | 227 ++++++++++++++++++ .../embedding/moore/embedding_moore_kernel.h | 116 +++++++++ .../ops/embedding/nvidia/embedding_nvidia.cu | 57 ++++- src/infiniop/ops/embedding/operator.cc | 65 +++-- 8 files changed, 678 insertions(+), 82 deletions(-) rename src/infiniop/ops/embedding/{nvidia => cuda}/embedding_kernel.cuh (59%) create mode 100644 src/infiniop/ops/embedding/metax/embedding_metax.cuh create mode 100644 src/infiniop/ops/embedding/metax/embedding_metax.maca create mode 100644 src/infiniop/ops/embedding/moore/embedding_moore.h create mode 100644 src/infiniop/ops/embedding/moore/embedding_moore.mu create mode 100644 src/infiniop/ops/embedding/moore/embedding_moore_kernel.h diff --git a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh b/src/infiniop/ops/embedding/cuda/embedding_kernel.cuh similarity index 59% rename from src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh rename to src/infiniop/ops/embedding/cuda/embedding_kernel.cuh index 0e85b5f6a..2914f06ed 100644 --- a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh +++ b/src/infiniop/ops/embedding/cuda/embedding_kernel.cuh @@ -1,13 +1,8 @@ #ifndef __EMBEDDING_CUDA_KERNEL_CUH__ #define __EMBEDDING_CUDA_KERNEL_CUH__ -#include "../../../devices/nvidia/nvidia_kernel_common.cuh" -#include -#include #include -namespace op::embedding::nvidia { - // Helper function to check memory alignment __forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) { // Use size_t for pointer arithmetic in device code (more compatible) @@ -118,61 +113,4 @@ __forceinline__ __device__ void copyScalar( } } -template -INFINIOP_CUDA_KERNEL embeddingKernel( - T *__restrict__ output, - const IndexType *__restrict__ indices, - const T *__restrict__ weight, - size_t num_indices, - size_t embedding_dim, - size_t vocab_size) { - // Calculate global thread index - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < num_indices) { - // Get the index value - IndexType index_val = __ldg(&indices[idx]); - - // Bounds check - handle negative indices gracefully - if (index_val >= 0 && static_cast(index_val) < vocab_size) { - // Copy embedding vector from weight to output - const T *src = weight + static_cast(index_val) * embedding_dim; - T *dst = output + idx * embedding_dim; - - // Choose optimal copy 
strategy based on type and alignment - if constexpr (std::is_same_v) { - // Check alignment for float4 (16 bytes) - bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16); - if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) { - copyVectorizedFloat4(dst, src, embedding_dim); - } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) { - // Try float2 if not aligned to 16 bytes - copyVectorizedFloat2(dst, src, embedding_dim); - } else { - copyScalar(dst, src, embedding_dim); - } - } else if constexpr (std::is_same_v) { - // Use half2 for vectorized access - if (embedding_dim >= 2 && embedding_dim % 2 == 0) { - copyVectorizedHalf2(dst, src, embedding_dim); - } else { - copyScalar(dst, src, embedding_dim); - } - } else if constexpr (std::is_same_v) { - // Use bfloat162 for vectorized access - if (embedding_dim >= 2 && embedding_dim % 2 == 0) { - copyVectorizedBFloat162(dst, src, embedding_dim); - } else { - copyScalar(dst, src, embedding_dim); - } - } else { - // Fallback to scalar copy with __ldg - copyScalar(dst, src, embedding_dim); - } - } - } -} - -} // namespace op::embedding::nvidia - #endif // __EMBEDDING_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/embedding/metax/embedding_metax.cuh b/src/infiniop/ops/embedding/metax/embedding_metax.cuh new file mode 100644 index 000000000..7290fc918 --- /dev/null +++ b/src/infiniop/ops/embedding/metax/embedding_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_METAX_H__ +#define __EMBEDDING_METAX_H__ + +#include "../embedding.h" + +DESCRIPTOR(metax) + +#endif // __EMBEDDING_METAX_H__ diff --git a/src/infiniop/ops/embedding/metax/embedding_metax.maca b/src/infiniop/ops/embedding/metax/embedding_metax.maca new file mode 100644 index 000000000..8a1b24ea0 --- /dev/null +++ b/src/infiniop/ops/embedding/metax/embedding_metax.maca @@ -0,0 +1,217 @@ +#include "../../../../utils.h" +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" +#include "../../../tensor.h" +#include "../cuda/embedding_kernel.cuh" +#include "embedding_metax.cuh" + +template +INFINIOP_METAX_KERNEL embeddingKernel( + T *__restrict__ output, + const IndexType *__restrict__ indices, + const T *__restrict__ weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + // Calculate global thread index + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < num_indices) { + // Get the index value + IndexType index_val = __ldg(&indices[idx]); + + // Bounds check - handle negative indices gracefully + if (index_val >= 0 && static_cast(index_val) < vocab_size) { + // Copy embedding vector from weight to output + const T *src = weight + static_cast(index_val) * embedding_dim; + T *dst = output + idx * embedding_dim; + + // Choose optimal copy strategy based on type and alignment + if constexpr (std::is_same_v) { + // Check alignment for float4 (16 bytes) + bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16); + if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) { + copyVectorizedFloat4(dst, src, embedding_dim); + } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + // Try float2 if not aligned to 16 bytes + copyVectorizedFloat2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use half2 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedHalf2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr 
(std::is_same_v) { + // Use bfloat162 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedBFloat162(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else { + // Fallback to scalar copy with __ldg + copyScalar(dst, src, embedding_dim); + } + } + } +} + +namespace op::embedding::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + // Validate shapes + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + // Check output shape matches input shape + embedding_dim + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + // Validate dtypes + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || + weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + // Calculate number of indices (supporting batch dimension) + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + size_t vocab_size = weight_shape[0]; + + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + auto hc_stream = reinterpret_cast(stream); + + // Dynamic block size optimization based on embedding_dim for Metax platform + size_t block_size = 256; // Default block size for Metax + if (_embedding_dim <= 64) { + block_size = 512; // Small embedding_dim: use larger block for better occupancy + } else if (_embedding_dim >= 1024) { + block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure + } + + size_t grid_size = (_num_indices + block_size - 1) / block_size; + + // Launch kernel based on dtypes for Metax platform + if (_input_dtype == INFINI_DTYPE_I32) { + const int32_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + 
// Use Metax's bfloat16 type + embeddingKernel<__hpcc_bfloat16, int32_t><<>>( + reinterpret_cast<__hpcc_bfloat16 *>(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + const int64_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + embeddingKernel<__hpcc_bfloat16, int64_t><<>>( + reinterpret_cast<__hpcc_bfloat16 *>(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::metax diff --git a/src/infiniop/ops/embedding/moore/embedding_moore.h b/src/infiniop/ops/embedding/moore/embedding_moore.h new file mode 100644 index 000000000..edc397be6 --- /dev/null +++ b/src/infiniop/ops/embedding/moore/embedding_moore.h @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_MOORE_H__ +#define __EMBEDDING_MOORE_H__ + +#include "../embedding.h" + +DESCRIPTOR(moore) + +#endif // __EMBEDDING_MOORE_H__ diff --git a/src/infiniop/ops/embedding/moore/embedding_moore.mu b/src/infiniop/ops/embedding/moore/embedding_moore.mu new file mode 100644 index 000000000..147ac830f --- /dev/null +++ b/src/infiniop/ops/embedding/moore/embedding_moore.mu @@ -0,0 +1,227 @@ +#include "../../../../utils.h" +#include "../../../devices/moore/moore_common.h" +#include "../../../devices/moore/moore_kernel_common.h" +#include "../../../tensor.h" +#include "embedding_moore_kernel.h" +#include "embedding_moore.h" +#include + +template +INFINIOP_MOORE_KERNEL embeddingKernel( + T *__restrict__ output, + const IndexType *__restrict__ indices, + const T *__restrict__ weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + // Calculate global thread index + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < num_indices) { + // Get the index value with Moore-optimized memory access + IndexType index_val = indices[idx]; + + // Bounds check - handle negative indices gracefully + if (index_val >= 0 && static_cast(index_val) < vocab_size) { + // Copy embedding vector from weight to output + const T *src = weight + static_cast(index_val) * embedding_dim; + T *dst = output + idx * embedding_dim; + + // Choose optimal copy strategy based on type and alignment + if constexpr (std::is_same_v) { + // Check alignment for float4 (16 bytes) + bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16); + if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) { + copyVectorizedFloat4(dst, src, embedding_dim); + } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + // Try float2 if not aligned to 16 bytes + copyVectorizedFloat2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use half2 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedHalf2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, 
embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use mt_bfloat162 for vectorized access (Moore-specific type) + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedBFloat162(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else { + // Fallback to scalar copy with Moore-optimized memory access + copyScalar(dst, src, embedding_dim); + } + } + } +} + +namespace op::embedding::moore { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + // Validate shapes + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + // Check output shape matches input shape + embedding_dim + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + // Validate dtypes + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + // Calculate number of indices (supporting batch dimension) + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + size_t vocab_size = weight_shape[0]; + + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + auto musa_stream = reinterpret_cast(stream); + + // Dynamic block size optimization based on embedding_dim + // Moore platform typically has different performance characteristics + size_t block_size = 256; // Default for Moore + if (_embedding_dim <= 64) { + block_size = 512; // Small embedding_dim: use larger block for better occupancy + } else if (_embedding_dim >= 1024) { + block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure + } else if (_embedding_dim <= 256) { + block_size = 384; // Medium embedding_dim: balanced configuration + } + + size_t grid_size = (_num_indices + block_size - 1) / block_size; + + // Launch kernel based on dtypes + // Note: Moore uses __mt_bfloat16 instead of __nv_bfloat16 + if (_input_dtype == INFINI_DTYPE_I32) { + const int32_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, 
+ _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + // Use Moore's bfloat16 type + embeddingKernel<__mt_bfloat16, int32_t><<>>( + reinterpret_cast<__mt_bfloat16 *>(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + const int64_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + embeddingKernel<__mt_bfloat16, int64_t><<>>( + reinterpret_cast<__mt_bfloat16 *>(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // Check for kernel launch errors + musaError_t err = musaGetLastError(); + if (err != musaSuccess) { + return INFINI_STATUS_INTERNAL_ERROR; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::moore diff --git a/src/infiniop/ops/embedding/moore/embedding_moore_kernel.h b/src/infiniop/ops/embedding/moore/embedding_moore_kernel.h new file mode 100644 index 000000000..9a7427b05 --- /dev/null +++ b/src/infiniop/ops/embedding/moore/embedding_moore_kernel.h @@ -0,0 +1,116 @@ +#ifndef __EMBEDDING_MOORE_KERNEL_CUH__ +#define __EMBEDDING_MOORE_KERNEL_CUH__ + +#include + +// Helper function to check memory alignment +__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) { + // Use size_t for pointer arithmetic in device code (more compatible) + return (reinterpret_cast(ptr) % alignment == 0); +} + +// Vectorized copy for float type using float4 +template +__forceinline__ __device__ void copyVectorizedFloat4( + float *__restrict__ dst, + const float *__restrict__ src, + size_t embedding_dim) { + // Use float4 for vectorized access (16 bytes, 4 floats) + const float4 *src_vec = reinterpret_cast(src); + float4 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 4; + + // Vectorized copy with __ldg equivalent for Moore platform + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = src_vec[i]; + } + + // Copy remaining elements + size_t remaining = embedding_dim % 4; + if (remaining > 0) { + size_t offset = vec_count * 4; + for (size_t i = 0; i < remaining; ++i) { + dst[offset + i] = src[offset + i]; + } + } +} + +// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes) +template +__forceinline__ __device__ void copyVectorizedFloat2( + float *__restrict__ dst, + const float *__restrict__ src, + size_t embedding_dim) { + // Use float2 for vectorized access (8 bytes, 2 floats) + const float2 *src_vec = reinterpret_cast(src); + float2 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy with Moore-optimized memory access + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = src_vec[i]; + } + + // 
Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = src[embedding_dim - 1]; + } +} + +// Vectorized copy for half type using half2 +template +__forceinline__ __device__ void copyVectorizedHalf2( + half *__restrict__ dst, + const half *__restrict__ src, + size_t embedding_dim) { + // Use half2 for vectorized access (4 bytes, 2 halfs) + const half2 *src_vec = reinterpret_cast(src); + half2 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy optimized for Moore architecture + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = src_vec[i]; + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = src[embedding_dim - 1]; + } +} + +// Vectorized copy for Moore bfloat16 type using bfloat162 +template +__forceinline__ __device__ void copyVectorizedBFloat162( + __mt_bfloat16 *__restrict__ dst, + const __mt_bfloat16 *__restrict__ src, + size_t embedding_dim) { + // Use mt_bfloat162 for vectorized access (4 bytes, 2 bfloat16s) + const __mt_bfloat162 *src_vec = reinterpret_cast(src); + __mt_bfloat162 *dst_vec = reinterpret_cast<__mt_bfloat162 *>(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy with Moore-specific optimization + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = src_vec[i]; + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = src[embedding_dim - 1]; + } +} + +// Scalar copy fallback with Moore-optimized memory access +template +__forceinline__ __device__ void copyScalar( + T *__restrict__ dst, + const T *__restrict__ src, + size_t embedding_dim) { + // Scalar copy with Moore read-only weight optimization + for (size_t i = 0; i < embedding_dim; ++i) { + dst[i] = src[i]; + } +} + +#endif // __EMBEDDING_MOORE_KERNEL_CUH__ diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu index b714b0aa4..8414e187e 100644 --- a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu @@ -2,10 +2,65 @@ #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh" #include "../../../tensor.h" -#include "embedding_kernel.cuh" +#include "../cuda/embedding_kernel.cuh" #include "embedding_nvidia.cuh" #include +template +INFINIOP_CUDA_KERNEL embeddingKernel( + T *__restrict__ output, + const IndexType *__restrict__ indices, + const T *__restrict__ weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + // Calculate global thread index + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < num_indices) { + // Get the index value + IndexType index_val = __ldg(&indices[idx]); + + // Bounds check - handle negative indices gracefully + if (index_val >= 0 && static_cast(index_val) < vocab_size) { + // Copy embedding vector from weight to output + const T *src = weight + static_cast(index_val) * embedding_dim; + T *dst = output + idx * embedding_dim; + + // Choose optimal copy strategy based on type and alignment + if constexpr (std::is_same_v) { + // Check alignment for float4 (16 bytes) + bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16); + if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) { + copyVectorizedFloat4(dst, src, embedding_dim); + } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + // Try float2 if not aligned to 16 bytes + copyVectorizedFloat2(dst, src, 
embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use half2 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedHalf2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use bfloat162 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedBFloat162(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else { + // Fallback to scalar copy with __ldg + copyScalar(dst, src, embedding_dim); + } + } + } +} + namespace op::embedding::nvidia { struct Descriptor::Opaque { diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc index 50f2f05ed..09cd1f737 100644 --- a/src/infiniop/ops/embedding/operator.cc +++ b/src/infiniop/ops/embedding/operator.cc @@ -8,6 +8,12 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) #include "nvidia/embedding_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/embedding_metax.cuh" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/embedding_moore.h" +#endif __C infiniStatus_t infiniopCreateEmbeddingDescriptor( infiniopHandle_t handle, @@ -30,18 +36,24 @@ __C infiniStatus_t infiniopCreateEmbeddingDescriptor( #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#ifdef ENABLE_NVIDIA_API CREATE(INFINI_DEVICE_NVIDIA, nvidia); #endif -#if defined(ENABLE_ILUVATAR_API) +#ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif -#if defined(ENABLE_QY_API) +#ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif -#if defined(ENABLE_HYGON_API) +#ifdef ENABLE_HYGON_API CREATE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -67,18 +79,24 @@ __C infiniStatus_t infiniopEmbedding( #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#ifdef ENABLE_NVIDIA_API CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif -#if defined(ENABLE_ILUVATAR_API) +#ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif -#if defined(ENABLE_QY_API) +#ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif -#if defined(ENABLE_HYGON_API) +#ifdef ENABLE_HYGON_API CALCULATE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -89,30 +107,39 @@ __C infiniStatus_t infiniopEmbedding( __C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) { -#define DELETE(CASE, NAMESPACE) \ +#define DESTROY(CASE, NAMESPACE) \ case CASE: \ delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; switch (desc->device_type) { #ifdef ENABLE_CPU_API - DELETE(INFINI_DEVICE_CPU, cpu); + DESTROY(INFINI_DEVICE_CPU, cpu); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - 
DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + DESTROY(INFINI_DEVICE_QY, nvidia); #endif -#if defined(ENABLE_ILUVATAR_API) - DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia); #endif -#if defined(ENABLE_QY_API) - DELETE(INFINI_DEVICE_QY, nvidia); +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); #endif -#if defined(ENABLE_HYGON_API) - DELETE(INFINI_DEVICE_HYGON, nvidia); +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore); #endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } -#undef DELETE +#undef DESTROY return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } From eb34d4d6490f18794d954c7fef21c9b7dc60217b Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 9 Jan 2026 17:25:57 +0800 Subject: [PATCH 05/25] issue/900 - adapt to graph and adjust test script --- include/infinicore/ops/embedding.hpp | 13 +- src/infinicore/ops/embedding/embedding.cc | 26 +- .../ops/embedding/embedding_infiniop.cc | 63 ++-- .../graph/test_embedding_graph_recording.py | 291 ++++++++++++++++ .../EMBEDDING_GRAPH_RECORDING_COMPARISON.md | 159 --------- .../nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md | 317 ------------------ .../nn/test_embedding_graph_recording.py | 284 ---------------- 7 files changed, 334 insertions(+), 819 deletions(-) create mode 100644 test/infinicore/graph/test_embedding_graph_recording.py delete mode 100644 test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md delete mode 100644 test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md delete mode 100644 test/infinicore/nn/test_embedding_graph_recording.py diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp index 6be997134..43f18d090 100644 --- a/include/infinicore/ops/embedding.hpp +++ b/include/infinicore/ops/embedding.hpp @@ -1,16 +1,13 @@ #pragma once +#include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Embedding { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor out, Tensor input, Tensor weight); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(Embedding, Tensor, const Tensor &, const Tensor &); -Tensor embedding(Tensor input, Tensor weight); -void embedding_(Tensor out, Tensor input, Tensor weight); +Tensor embedding(const Tensor &input, const Tensor &weight); +void embedding_(Tensor out, const Tensor &input, const Tensor &weight); } // namespace infinicore::op diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index 96f19803c..2dfd3aa21 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -5,27 +5,19 @@ #include namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding); -common::OpDispatcher &Embedding::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -} - -void Embedding::execute(Tensor out, Tensor input, Tensor weight) { - // Check that all tensors are on the same device - // This is critical: if input is on CPU while out/weight are on GPU, - // passing CPU pointer to CUDA kernel will cause memory access errors +Embedding::Embedding(Tensor out, const Tensor &input, const Tensor &weight) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight); + 
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, input, weight); +} - // Set device context - infinicore::context::setDevice(out->device()); - - // Use dispatcher to lookup kernel (infiniop implementation) - dispatcher().lookup(out->device().getType())(out, input, weight); +void Embedding::execute(Tensor out, const Tensor &input, const Tensor &weight) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Embedding, out, input, weight); } -Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract - Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1 +Tensor embedding(const Tensor &input, // LongTensor of arbitrary shape containing the indices to extract + const Tensor &weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1 ) { auto input_shape = input->shape(); auto weight_shape = weight->shape(); @@ -40,7 +32,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i return inputs_embeds; } -void embedding_(Tensor out, Tensor input, Tensor weight) { +void embedding_(Tensor out, const Tensor &input, const Tensor &weight) { Embedding::execute(out, input, weight); } diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc index dfbbb2f71..a9780d3ae 100644 --- a/src/infinicore/ops/embedding/embedding_infiniop.cc +++ b/src/infinicore/ops/embedding/embedding_infiniop.cc @@ -1,49 +1,44 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" +#include "../infiniop_impl.hpp" #include "infinicore/ops/embedding.hpp" -#include namespace infinicore::op::embedding_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopEmbeddingDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Embedding, 100); -void calculate(Tensor out, Tensor input, Tensor weight) { +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor out, input, weight; +}; + +void *plan(Tensor out, const Tensor &input, const Tensor &weight) { size_t seed = hash_combine(out, input, weight); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Embedding, + seed, out->desc(), input->desc(), weight->desc()); + + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(out), + graph::GraphTensor(input), + graph::GraphTensor(weight)}; - auto desc_opt = cache.get(seed); - infiniopEmbeddingDescriptor_t desc = nullptr; + return planned; +} - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor( - context::getInfiniopHandle(device), &desc, - out->desc(), input->desc(), weight->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopEmbedding( - desc, - out->data(), - input->data(), - weight->data(), - context::getStream())); + planned->descriptor->desc, + planned->out->data(), planned->input->data(), planned->weight->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - 
Embedding::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Embedding, &plan, &run, cleanup); } // namespace infinicore::op::embedding_impl::infiniop diff --git a/test/infinicore/graph/test_embedding_graph_recording.py b/test/infinicore/graph/test_embedding_graph_recording.py new file mode 100644 index 000000000..3795c3ae5 --- /dev/null +++ b/test/infinicore/graph/test_embedding_graph_recording.py @@ -0,0 +1,291 @@ +""" +Test if embedding supports CUDA Graph recording + +Usage: + python test/infinicore/nn/test_embedding_graph_recording.py + +Key verification points: +1. Before modification: indices->to(cpu_device) triggers synchronous D2H copy, causing graph recording to fail +2. After modification: Uses device-side CUDA kernel, fully asynchronous, supports graph recording + +Expected results: +- Before modification: Graph recording fails, device-side input may fail +- After modification: Graph recording succeeds, device-side input succeeds +""" + +import infinicore +import torch + + +def test_embedding_graph_recording(): + """Test if embedding supports CUDA Graph recording""" + print("=" * 60) + print("Testing Embedding Graph Recording Support") + print("=" * 60) + + # Check if CUDA is available + if not torch.cuda.is_available(): + print("⚠ CUDA not available, skipping graph recording test") + return False + + device = infinicore.device("cuda", 0) + + # Create embedding module + vocab_size = 1000 + embedding_dim = 128 + embedding = infinicore.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + dtype=infinicore.float32, + device=device, + ) + + # Create device-side input_ids (key point: unsupported before modification, supported after) + batch_size = 4 + seq_len = 32 + input_ids_device = infinicore.from_list( + [[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)], + dtype=infinicore.int64, + device=device, + ) + + print(f"\n1. Input tensor information:") + print(f" - Shape: {input_ids_device.shape}") + print(f" - Device: {input_ids_device.device.type}") + print(f" - Dtype: {input_ids_device.dtype}") + + # Attempt CUDA Graph recording + print(f"\n2. 
Attempting CUDA Graph recording...") + + # Use PyTorch's CUDA Graph API for testing (simpler and more reliable) + try: + # Set device + infinicore.set_device(device) + + # Use PyTorch's CUDA Graph API + # Note: PyTorch 2.0+ supports torch.cuda.graph + try: + # Method 1: Use PyTorch CUDA Graph (recommended) + print(" Using PyTorch CUDA Graph API for testing...") + + # Create warmup input + warmup_input = input_ids_device + + # Warmup (need to execute once before graph recording, including memory allocation) + embedding.forward(warmup_input) + infinicore.sync_stream() # Synchronize to ensure warmup completes + + # Pre-allocate output tensor (CUDA Graph doesn't support dynamic memory allocation) + # Output shape: input_shape + [embedding_dim] + output_shape = list(input_ids_device.shape) + [embedding_dim] + output = infinicore.empty( + output_shape, dtype=embedding.weight.dtype, device=device + ) + + # Warmup embedding (ensure memory allocation is complete) + import infinicore.nn.functional as F + + F.embedding(warmup_input, embedding.weight, out=output) + infinicore.sync_stream() + + # Start graph recording (using pre-allocated output) + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # Use embedding's out parameter (in-place), passing pre-allocated output + F.embedding(input_ids_device, embedding.weight, out=output) + + print(" ✓ Graph recording successful!") + print(" ✓ Embedding supports CUDA Graph recording") + + # Verify graph can be replayed + graph.replay() + infinicore.sync_stream() + + print(" ✓ Graph can be successfully replayed") + return True + + except AttributeError: + # PyTorch version may not support torch.cuda.graph + print( + " ⚠ PyTorch version doesn't support torch.cuda.graph, using simplified verification method" + ) + return test_embedding_async_verification(embedding, input_ids_device) + except RuntimeError as e: + error_msg = str(e) + if "capture" in error_msg.lower() or "graph" in error_msg.lower(): + print(f" ✗ Graph recording failed: {e}") + print( + " ✗ Embedding doesn't support CUDA Graph recording (may contain synchronous operations)" + ) + return False + else: + print(f" ⚠ Graph recording test exception: {e}") + return test_embedding_async_verification(embedding, input_ids_device) + + except Exception as e: + print(f" ⚠ Graph recording test exception: {e}") + print(" Using simplified verification method...") + import traceback + + traceback.print_exc() + return test_embedding_async_verification(embedding, input_ids_device) + + +def test_embedding_async_verification(embedding, input_ids_device): + """ + Simplified verification: Check if there are synchronous operations + + Key checkpoints: + 1. Whether input can be on device (needed CPU before modification, supports device after) + 2. Whether operations are fully asynchronous (no synchronization points) + """ + print("\n3. 
Simplified verification: Checking asynchronous operation support") + + # Verification 1: Input can be on device + if input_ids_device.device.type != "cuda": + print(" ✗ Input not on device, cannot verify") + return False + + print(" ✓ Input is on device") + + # Verification 2: Execute forward, check for synchronous operations + # Before modification, this would call indices->to(cpu_device), triggering synchronization + # After modification, directly uses device-side kernel, fully asynchronous + + try: + # Record start time + start_event = infinicore.DeviceEvent(enable_timing=True) + end_event = infinicore.DeviceEvent(enable_timing=True) + + start_event.record() + output = embedding.forward(input_ids_device) + end_event.record() + + # Don't synchronize immediately, check if operation is asynchronous + # If operation is asynchronous, query should return False (not completed) + # If operation is synchronous, may have already completed + + # Wait a short time + import time + + time.sleep(0.001) # 1ms + + # Check event status + is_complete = end_event.query() + + if not is_complete: + print(" ✓ Operation is asynchronous (event not immediately completed)") + else: + print( + " ⚠ Operation may contain synchronization points (event immediately completed)" + ) + + # Synchronize and measure time + end_event.synchronize() + elapsed = start_event.elapsed_time(end_event) + + print(f" ✓ Forward execution time: {elapsed:.3f} ms") + print(f" ✓ Output shape: {output.shape}") + print(f" ✓ Output device: {output.device.type}") + + # Verify output correctness + embedding_dim = embedding.embedding_dim() + expected_shape = (*input_ids_device.shape, embedding_dim) + if output.device.type == "cuda" and output.shape == expected_shape: + print(" ✓ Output on device, shape correct") + return True + else: + print(f" ✗ Output verification failed") + print( + f" Expected shape: {expected_shape}, actual shape: {output.shape}" + ) + print(f" Expected device: cuda, actual device: {output.device.type}") + return False + + except Exception as e: + print(f" ✗ Verification failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_embedding_device_input_support(): + """Test if embedding supports device-side input""" + print("\n" + "=" * 60) + print("Testing Embedding Device-side Input Support") + print("=" * 60) + + if not torch.cuda.is_available(): + print("⚠ CUDA not available, skipping test") + return False + + device = infinicore.device("cuda", 0) + vocab_size = 100 + embedding_dim = 64 + + embedding = infinicore.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + dtype=infinicore.float32, + device=device, + ) + + # Test 1: Device-side input (supported after modification) + print("\nTest 1: Device-side input") + try: + input_ids_device = infinicore.from_list( + [[1, 2, 3, 4, 5]], dtype=infinicore.int64, device=device + ) + output = embedding.forward(input_ids_device) + print(f" ✓ Device-side input successful") + print(f" - Input device: {input_ids_device.device.type}") + print(f" - Output device: {output.device.type}") + print(f" - Output shape: {output.shape}") + return True + except Exception as e: + print(f" ✗ Device-side input failed: {e}") + return False + + +def main(): + """Main test function""" + print("\n" + "=" * 60) + print("Embedding Graph Recording Support Verification") + print("=" * 60) + + results = [] + + # Test 1: Graph recording support + result1 = test_embedding_graph_recording() + results.append(("CUDA Graph Recording", result1)) + + # Test 2: 
Device-side input support + result2 = test_embedding_device_input_support() + results.append(("Device-side Input", result2)) + + # Summary + print("\n" + "=" * 60) + print("Test Results Summary") + print("=" * 60) + + all_passed = True + for test_name, result in results: + status = "✓ Passed" if result else "✗ Failed" + print(f"{test_name}: {status}") + if not result: + all_passed = False + + print("\n" + "=" * 60) + if all_passed: + print("✓ All tests passed! Embedding supports graph recording") + else: + print("✗ Some tests failed, embedding may not fully support graph recording") + print("=" * 60) + + return all_passed + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md b/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md deleted file mode 100644 index 686c10a1b..000000000 --- a/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md +++ /dev/null @@ -1,159 +0,0 @@ -# Embedding 图录制支持对比 - -## 改动前后对比 - -### ❌ 改动前:不支持图录制 - -**关键问题代码**(在 `nn::Embedding::forward` 中): -```cpp -// 改动前的实现 -Tensor Embedding::forward(const Tensor &indices) const { - auto cpu_device = Device(Device::Type::CPU, 0); - auto indices_cpu = indices->to(cpu_device)->contiguous(); // ❌ 同步操作! - - // ... 后续处理 -} -``` - -**问题分析**: -1. `indices->to(cpu_device)` 会触发 **同步的 D2H(Device-to-Host)内存拷贝** -2. CUDA Graph 录制要求所有操作都是**异步的**,不能有同步点 -3. 同步操作会导致图录制失败或产生错误 - -**验证方法**: -```python -# 改动前:这个操作会失败或产生同步 -input_ids_device = infinicore.from_list(..., device="cuda:0") # 设备端输入 -output = embedding.forward(input_ids_device) # ❌ 内部会同步拷贝到 CPU -``` - ---- - -### ✅ 改动后:支持图录制 - -**关键改进代码**: -```cpp -// 改动后的实现 -Tensor Embedding::forward(const Tensor &indices) const { - Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous(); - return op::embedding(indices_contiguous, weight_); // ✅ 直接使用设备端 kernel -} -``` - -**改进点**: -1. **移除了同步操作**:不再调用 `indices->to(cpu_device)` -2. **使用设备端 CUDA kernel**:通过 InfiniOP 调用 `embeddingKernel`,完全在设备端执行 -3. **完全异步**:所有操作都在 CUDA stream 上异步执行 - -**实现位置**: -- CUDA Kernel: `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` -- Kernel 启动:使用 `cudaStream_t`,完全异步 -- 无同步点:没有 `cudaDeviceSynchronize()` 或 D2H 拷贝 - -**验证方法**: -```python -# 改动后:这个操作完全异步,支持图录制 -input_ids_device = infinicore.from_list(..., device="cuda:0") # 设备端输入 -output = embedding.forward(input_ids_device) # ✅ 直接使用设备端 kernel,无同步 -``` - ---- - -## 验证方法 - -### 方法 1: 代码检查 - -**检查点**: -1. ✅ 是否有 `->to(cpu_device)` 调用? -2. ✅ 是否有 `synchronize()` 调用? -3. ✅ 是否有设备端 kernel 实现? 
- -**改动前**: -```cpp -// ❌ 有同步操作 -auto indices_cpu = indices->to(cpu_device)->contiguous(); -``` - -**改动后**: -```cpp -// ✅ 无同步操作,直接使用设备端 kernel -return op::embedding(indices_contiguous, weight_); -``` - -### 方法 2: CUDA Graph API 测试 - -运行测试脚本: -```bash -python test/infinicore/nn/test_embedding_graph_recording.py -``` - -**预期结果**: -- ✅ 改动后:图录制成功 -- ❌ 改动前:图录制失败(因为同步操作) - -### 方法 3: 设备端输入测试 - -**关键测试**: -```python -# 创建设备端输入 -input_ids = infinicore.from_list([[1, 2, 3]], dtype=int64, device="cuda:0") - -# 执行 forward -output = embedding.forward(input_ids) # 改动前会失败或同步,改动后成功 -``` - -**改动前**: -- 需要先将 `input_ids` 拷贝到 CPU -- 触发同步操作,无法图录制 - -**改动后**: -- 直接使用设备端 `input_ids` -- 完全异步,支持图录制 - ---- - -## 技术细节对比 - -| 特性 | 改动前 | 改动后 | -|------|--------|--------| -| **输入设备** | 必须在 CPU | 支持设备端 | -| **同步操作** | ❌ 有(D2H拷贝) | ✅ 无 | -| **Kernel位置** | CPU 实现 | CUDA kernel | -| **图录制支持** | ❌ 不支持 | ✅ 支持 | -| **Batch维度** | ✅ 支持 | ✅ 支持 | -| **性能** | 较慢(同步开销) | 更快(异步) | - ---- - -## 关键代码位置 - -### 改动前的问题代码 -- `src/infinicore/nn/embedding.cc` (旧版本) - - 第58行:`indices->to(cpu_device)->contiguous()` ❌ - -### 改动后的实现 -- `src/infinicore/nn/embedding.cc` (新版本) - - 第48行:`indices->is_contiguous() ? indices : indices->contiguous()` ✅ - - 第52行:`return op::embedding(indices_contiguous, weight_)` ✅ - -- `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` - - CUDA kernel 实现,完全异步 ✅ - -- `src/infinicore/ops/embedding/embedding_infiniop.cc` - - InfiniOP 包装,调用设备端 kernel ✅ - ---- - -## 总结 - -**改动前的关键问题**: -- ❌ `indices->to(cpu_device)` 触发同步 D2H 拷贝 -- ❌ 无法进行 CUDA Graph 录制 -- ❌ 性能较差(同步开销) - -**改动后的改进**: -- ✅ 移除所有同步操作 -- ✅ 使用设备端 CUDA kernel -- ✅ 完全支持 CUDA Graph 录制 -- ✅ 性能更好(完全异步) - diff --git a/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md b/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md deleted file mode 100644 index e5e60db2b..000000000 --- a/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md +++ /dev/null @@ -1,317 +0,0 @@ -# Embedding 图录制测试使用指南 - -## 🚀 快速开始 - -### 运行测试 - -```bash -cd /home/zhuyue/codes/InfiniCore -python test/infinicore/nn/test_embedding_graph_recording.py -``` - ---- - -## 📊 改动前后对比 - -### ❌ 改动前:不支持图录制 - -#### 1. 运行测试 - -```bash -python test/infinicore/nn/test_embedding_graph_recording.py -``` - -#### 2. 预期输出 - -``` -============================================================ -Embedding 图录制支持验证 -============================================================ -============================================================ -测试 Embedding 图录制支持 -============================================================ - -1. 输入张量信息: - - Shape: [4, 32] - - Device: cuda - - Dtype: int64 - -2. 尝试 CUDA Graph 录制... - 使用 PyTorch CUDA Graph API 测试... - ✗ 图录制失败: [错误信息] - ✗ Embedding 不支持 CUDA Graph 录制(可能包含同步操作) - -3. 简化验证:检查异步操作支持 - ✓ 输入在设备上 - ⚠ 操作可能包含同步点(事件立即完成) ← 关键:说明有同步操作 - ✓ Forward 执行时间: X.XXX ms - ✓ 输出形状: [4, 32, 128] - ✓ 输出设备: cuda - ✗ 输出验证失败 - -============================================================ -测试 Embedding 设备端输入支持 -============================================================ - -测试 1: 设备端输入 - ✗ 设备端输入失败: [错误信息] - -============================================================ -测试结果总结 -============================================================ -CUDA Graph 录制: ✗ 失败 -设备端输入: ✗ 失败 -============================================================ -✗ 部分测试失败,Embedding 可能不完全支持图录制 -============================================================ -``` - -#### 3. 关键失败点 - -- **图录制失败**:因为代码中有 `indices->to(cpu_device)` 同步操作 -- **设备端输入失败**:需要先将输入拷贝到 CPU -- **异步验证显示同步点**:事件立即完成,说明有同步操作 - ---- - -### ✅ 改动后:支持图录制 - -#### 1. 
运行测试 - -```bash -python test/infinicore/nn/test_embedding_graph_recording.py -``` - -#### 2. 预期输出 - -``` -============================================================ -Embedding 图录制支持验证 -============================================================ -============================================================ -测试 Embedding 图录制支持 -============================================================ - -1. 输入张量信息: - - Shape: [4, 32] - - Device: cuda - - Dtype: int64 - -2. 尝试 CUDA Graph 录制... - 使用 PyTorch CUDA Graph API 测试... - ✓ 成功完成图录制! - ✓ Embedding 支持 CUDA Graph 录制 - ✓ 图可以成功重放 - -============================================================ -测试 Embedding 设备端输入支持 -============================================================ - -测试 1: 设备端输入 - ✓ 设备端输入成功 - - 输入设备: cuda - - 输出设备: cuda - - 输出形状: [1, 5, 64] - -============================================================ -测试结果总结 -============================================================ -CUDA Graph 录制: ✓ 通过 -设备端输入: ✓ 通过 -============================================================ -✓ 所有测试通过!Embedding 支持图录制 -============================================================ -``` - -#### 3. 关键成功点 - -- **图录制成功**:所有操作都是异步的,无同步点 -- **设备端输入成功**:直接支持设备端输入,无需拷贝 -- **图可以重放**:验证图录制的正确性 - ---- - -## 🔍 如何判断当前是改动前还是改动后? - -### 方法 1: 代码检查(最快) - -```bash -# 检查是否有同步操作 -grep -n "to(cpu_device)" src/infinicore/nn/embedding.cc - -# 结果解读: -# - 有输出 → ❌ 改动前(不支持图录制) -# - 无输出 → ✅ 改动后(支持图录制) -``` - -### 方法 2: 检查设备端实现 - -```bash -# 检查是否有设备端 CUDA kernel -ls src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu - -# 结果解读: -# - 不存在 → ❌ 改动前(不支持图录制) -# - 存在 → ✅ 改动后(支持图录制) -``` - -### 方法 3: 运行测试(最准确) - -```bash -python test/infinicore/nn/test_embedding_graph_recording.py - -# 查看 "CUDA Graph 录制" 测试结果: -# - ✓ 通过 → ✅ 改动后(支持图录制) -# - ✗ 失败 → ❌ 改动前(不支持图录制) -``` - ---- - -## 📝 测试内容详解 - -### 测试 1: CUDA Graph 录制 - -**目的**:验证 embedding 是否可以在 CUDA Graph 中录制 - -**工作原理**: -1. 使用 PyTorch 的 `torch.cuda.CUDAGraph()` API -2. 在图录制模式下执行 `embedding.forward()` -3. 如果包含同步操作,录制会失败 -4. 如果完全异步,录制会成功 - -**改动前**: -- ❌ 录制失败:因为 `indices->to(cpu_device)` 触发同步 - -**改动后**: -- ✅ 录制成功:使用设备端 CUDA kernel,完全异步 - -### 测试 2: 设备端输入支持 - -**目的**:验证 embedding 是否支持设备端输入 - -**工作原理**: -1. 创建设备端的 `input_ids` -2. 直接调用 `embedding.forward(input_ids)` -3. 检查是否成功且输出在设备上 - -**改动前**: -- ❌ 可能需要先将输入拷贝到 CPU(同步操作) - -**改动后**: -- ✅ 直接支持设备端输入(完全异步) - -### 测试 3: 异步操作验证(备用) - -**目的**:当 CUDA Graph API 不可用时,使用事件验证异步性 - -**工作原理**: -1. 使用 `DeviceEvent` 记录操作时间 -2. 检查操作是否立即完成(同步)或异步执行 - -**改动前**: -- ⚠️ 事件立即完成,说明有同步操作 - -**改动后**: -- ✅ 事件未立即完成,说明是异步操作 - ---- - -## 🛠️ 故障排查 - -### 问题 1: PyTorch 版本不支持 CUDA Graph - -**现象**: -``` -⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法 -``` - -**解决**: -- 需要 PyTorch 2.0+ 版本 -- 测试会自动降级到简化验证方法 -- 简化验证也能检测是否支持图录制 - -### 问题 2: CUDA 不可用 - -**现象**: -``` -⚠ CUDA 不可用,跳过图录制测试 -``` - -**解决**: -- 确保 CUDA 设备可用 -- 测试需要 CUDA 环境 - -### 问题 3: 测试失败但不确定原因 - -**检查清单**: -1. ✅ 确认代码已编译(特别是 CUDA 支持) -2. ✅ 确认 CUDA 设备可用 -3. ✅ 检查 `src/infinicore/nn/embedding.cc` 是否还有 `to(cpu_device)` -4. ✅ 检查是否有 `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` - ---- - -## 💡 快速验证脚本 - -创建一个简单的验证脚本: - -```bash -#!/bin/bash -# quick_check.sh - -cd /home/zhuyue/codes/InfiniCore - -echo "=== 1. 代码检查 ===" -if grep -q "to(cpu_device)" src/infinicore/nn/embedding.cc; then - echo "❌ 改动前:发现同步操作 to(cpu_device)" -else - echo "✅ 改动后:无同步操作" -fi - -echo "" -echo "=== 2. 设备端实现检查 ===" -if [ -f "src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu" ]; then - echo "✅ 改动后:有设备端 CUDA kernel" -else - echo "❌ 改动前:无设备端 CUDA kernel" -fi - -echo "" -echo "=== 3. 
运行测试 ===" -python test/infinicore/nn/test_embedding_graph_recording.py -``` - -使用方法: -```bash -chmod +x quick_check.sh -./quick_check.sh -``` - ---- - -## 📋 总结 - -### 改动前特征 - -| 检查项 | 结果 | -|--------|------| -| 代码中有 `to(cpu_device)` | ✅ 有 | -| 有设备端 CUDA kernel | ❌ 无 | -| 图录制测试 | ❌ 失败 | -| 设备端输入 | ❌ 失败 | - -### 改动后特征 - -| 检查项 | 结果 | -|--------|------| -| 代码中有 `to(cpu_device)` | ❌ 无 | -| 有设备端 CUDA kernel | ✅ 有 | -| 图录制测试 | ✅ 成功 | -| 设备端输入 | ✅ 成功 | - -### 最简单的判断方法 - -**运行测试脚本**,查看 "CUDA Graph 录制" 测试结果: -- ✅ **通过** → 支持图录制(改动后) -- ❌ **失败** → 不支持图录制(改动前) - diff --git a/test/infinicore/nn/test_embedding_graph_recording.py b/test/infinicore/nn/test_embedding_graph_recording.py deleted file mode 100644 index 405f71e0d..000000000 --- a/test/infinicore/nn/test_embedding_graph_recording.py +++ /dev/null @@ -1,284 +0,0 @@ -""" -测试 embedding 是否支持 CUDA Graph 录制 - -使用方法: - python test/infinicore/nn/test_embedding_graph_recording.py - -关键验证点: -1. 改动前:indices->to(cpu_device) 会触发同步的 D2H 拷贝,导致图录制失败 -2. 改动后:使用设备端 CUDA kernel,完全异步,支持图录制 - -预期结果: -- 改动前:图录制失败,设备端输入可能失败 -- 改动后:图录制成功,设备端输入成功 -""" - -import infinicore -import torch -import ctypes - - -def test_embedding_graph_recording(): - """测试 embedding 是否支持 CUDA Graph 录制""" - print("=" * 60) - print("测试 Embedding 图录制支持") - print("=" * 60) - - # 检查是否有 CUDA - if not torch.cuda.is_available(): - print("⚠ CUDA 不可用,跳过图录制测试") - return False - - device = infinicore.device("cuda", 0) - - # 创建 embedding 模块 - vocab_size = 1000 - embedding_dim = 128 - embedding = infinicore.nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - dtype=infinicore.float32, - device=device - ) - - # 创建设备端的 input_ids(这是关键:改动前不支持,改动后支持) - batch_size = 4 - seq_len = 32 - input_ids_device = infinicore.from_list( - [[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)], - dtype=infinicore.int64, - device=device - ) - - print(f"\n1. 输入张量信息:") - print(f" - Shape: {input_ids_device.shape}") - print(f" - Device: {input_ids_device.device.type}") - print(f" - Dtype: {input_ids_device.dtype}") - - # 尝试使用 CUDA Graph 录制 - print(f"\n2. 
尝试 CUDA Graph 录制...") - - # 使用 PyTorch 的 CUDA Graph API 进行测试(更简单可靠) - try: - # 设置设备 - infinicore.set_device(device) - - # 使用 PyTorch 的 CUDA Graph API - # 注意:PyTorch 2.0+ 支持 torch.cuda.graph - try: - # 方法 1: 使用 PyTorch 的 CUDA Graph(推荐) - print(" 使用 PyTorch CUDA Graph API 测试...") - - # 创建 warmup 输入 - warmup_input = input_ids_device - - # Warmup(图录制前需要先执行一次,包括内存分配) - warmup_output = embedding.forward(warmup_input) - infinicore.sync_stream() # 同步确保 warmup 完成 - - # 预先分配输出张量(CUDA Graph 不支持动态内存分配) - # 输出形状: input_shape + [embedding_dim] - output_shape = list(input_ids_device.shape) + [embedding_dim] - output = infinicore.empty( - output_shape, - dtype=embedding.weight.dtype, - device=device - ) - - # Warmup embedding(确保内存分配完成) - import infinicore.nn.functional as F - F.embedding(warmup_input, embedding.weight, out=output) - infinicore.sync_stream() - - # 开始图录制(使用预先分配的 output) - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - # 使用 embedding 的 out 参数(in-place),传入预先分配的 output - F.embedding(input_ids_device, embedding.weight, out=output) - - print(" ✓ 成功完成图录制!") - print(" ✓ Embedding 支持 CUDA Graph 录制") - - # 验证图可以重复执行 - graph.replay() - infinicore.sync_stream() - - print(" ✓ 图可以成功重放") - return True - - except AttributeError: - # PyTorch 版本可能不支持 torch.cuda.graph - print(" ⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法") - return test_embedding_async_verification(embedding, input_ids_device) - except RuntimeError as e: - error_msg = str(e) - if "capture" in error_msg.lower() or "graph" in error_msg.lower(): - print(f" ✗ 图录制失败: {e}") - print(" ✗ Embedding 不支持 CUDA Graph 录制(可能包含同步操作)") - return False - else: - print(f" ⚠ 图录制测试异常: {e}") - return test_embedding_async_verification(embedding, input_ids_device) - - except Exception as e: - print(f" ⚠ 图录制测试异常: {e}") - print(" 使用简化验证方法...") - import traceback - traceback.print_exc() - return test_embedding_async_verification(embedding, input_ids_device) - - -def test_embedding_async_verification(embedding, input_ids_device): - """ - 简化验证:检查是否有同步操作 - - 关键检查点: - 1. 输入是否可以在设备上(改动前需要 CPU,改动后支持设备) - 2. 操作是否完全异步(没有同步点) - """ - print("\n3. 
简化验证:检查异步操作支持") - - # 验证 1: 输入可以在设备上 - if input_ids_device.device.type != "cuda": - print(" ✗ 输入不在设备上,无法验证") - return False - - print(" ✓ 输入在设备上") - - # 验证 2: 执行 forward,检查是否有同步操作 - # 如果改动前,这里会调用 indices->to(cpu_device),触发同步 - # 如果改动后,直接使用设备端 kernel,完全异步 - - try: - # 记录开始时间 - start_event = infinicore.DeviceEvent(enable_timing=True) - end_event = infinicore.DeviceEvent(enable_timing=True) - - start_event.record() - output = embedding.forward(input_ids_device) - end_event.record() - - # 不立即同步,检查操作是否异步 - # 如果操作是异步的,query 应该返回 False(未完成) - # 如果操作是同步的,可能已经完成 - - # 等待一小段时间 - import time - time.sleep(0.001) # 1ms - - # 检查事件状态 - is_complete = end_event.query() - - if not is_complete: - print(" ✓ 操作是异步的(事件未立即完成)") - else: - print(" ⚠ 操作可能包含同步点(事件立即完成)") - - # 同步并测量时间 - end_event.synchronize() - elapsed = start_event.elapsed_time(end_event) - - print(f" ✓ Forward 执行时间: {elapsed:.3f} ms") - print(f" ✓ 输出形状: {output.shape}") - print(f" ✓ 输出设备: {output.device.type}") - - # 验证输出正确性 - embedding_dim = embedding.embedding_dim() - expected_shape = (*input_ids_device.shape, embedding_dim) - if output.device.type == "cuda" and output.shape == expected_shape: - print(" ✓ 输出在设备上,形状正确") - return True - else: - print(f" ✗ 输出验证失败") - print(f" 期望形状: {expected_shape}, 实际形状: {output.shape}") - print(f" 期望设备: cuda, 实际设备: {output.device.type}") - return False - - except Exception as e: - print(f" ✗ 验证失败: {e}") - import traceback - traceback.print_exc() - return False - - -def test_embedding_device_input_support(): - """测试 embedding 是否支持设备端输入""" - print("\n" + "=" * 60) - print("测试 Embedding 设备端输入支持") - print("=" * 60) - - if not torch.cuda.is_available(): - print("⚠ CUDA 不可用,跳过测试") - return False - - device = infinicore.device("cuda", 0) - vocab_size = 100 - embedding_dim = 64 - - embedding = infinicore.nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - dtype=infinicore.float32, - device=device - ) - - # 测试 1: 设备端输入(改动后支持) - print("\n测试 1: 设备端输入") - try: - input_ids_device = infinicore.from_list( - [[1, 2, 3, 4, 5]], - dtype=infinicore.int64, - device=device - ) - output = embedding.forward(input_ids_device) - print(f" ✓ 设备端输入成功") - print(f" - 输入设备: {input_ids_device.device.type}") - print(f" - 输出设备: {output.device.type}") - print(f" - 输出形状: {output.shape}") - return True - except Exception as e: - print(f" ✗ 设备端输入失败: {e}") - return False - - -def main(): - """主测试函数""" - print("\n" + "=" * 60) - print("Embedding 图录制支持验证") - print("=" * 60) - - results = [] - - # 测试 1: 图录制支持 - result1 = test_embedding_graph_recording() - results.append(("CUDA Graph 录制", result1)) - - # 测试 2: 设备端输入支持 - result2 = test_embedding_device_input_support() - results.append(("设备端输入", result2)) - - # 总结 - print("\n" + "=" * 60) - print("测试结果总结") - print("=" * 60) - - all_passed = True - for test_name, result in results: - status = "✓ 通过" if result else "✗ 失败" - print(f"{test_name}: {status}") - if not result: - all_passed = False - - print("\n" + "=" * 60) - if all_passed: - print("✓ 所有测试通过!Embedding 支持图录制") - else: - print("✗ 部分测试失败,Embedding 可能不完全支持图录制") - print("=" * 60) - - return all_passed - - -if __name__ == "__main__": - success = main() - exit(0 if success else 1) From f9761a299922c2b789f3b8151f34f86e54b20c87 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 19 Jan 2026 09:43:05 +0800 Subject: [PATCH 06/25] issue/900 - maintains classic embedding for devices yet to be worked on --- src/infinicore/nn/embedding.cc | 89 ++++++++++++++++++++--- src/infinicore/ops/embedding/embedding.cc | 4 +- 2 files changed, 79 
insertions(+), 14 deletions(-) diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index 6aa86a4fa..75475b410 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -43,20 +43,87 @@ Embedding::Embedding(size_t num_embeddings, } Tensor Embedding::forward(const Tensor &indices) const { - // Ensure indices are on the same device as weight - // This avoids synchronous memcpy in ops layer which would hurt performance - Tensor indices_on_device = indices; - if (indices->device() != device_) { - indices_on_device = indices->to(device_); + // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach + auto device_type = device_.getType(); + if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE) { + // Use op::embedding which supports device-side input and batch dimension + return op::embedding(indices->contiguous()->to(device_), weight_); } - // Ensure indices are contiguous for efficient access - // op::embedding now supports device-side input for graph recording - Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous(); + // Get the shape of indices + auto indices_shape = indices->shape(); - // Use op::embedding which now supports device-side input and batch dimension - // This enables full graph recording support without synchronization - return op::embedding(indices_contiguous, weight_); + // Output shape: indices_shape + [embedding_dim] + std::vector output_shape = indices_shape; + output_shape.push_back(embedding_dim_); + + // Create output tensor on the same device as weight + auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device()); + + // Flatten indices for sequential row copies + auto cpu_device = Device(Device::Type::CPU, 0); + auto indices_cpu = indices->to(cpu_device)->contiguous(); + + // Calculate total number of lookups + size_t num_lookups = 1; + for (auto dim : indices_shape) { + num_lookups *= dim; + } + + const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype()); + + // Source and destination base pointers + auto *weight_base = weight_->data(); + auto *out_base = out->data(); + + // Helper lambda to read index based on dtype with bounds checking + auto read_index = [&](size_t i) -> int64_t { + auto dtype = indices_cpu->dtype(); + if (dtype == DataType::I32) { + const auto *data = reinterpret_cast(indices_cpu->data()); + return static_cast(data[i]); + } else if (dtype == DataType::I64) { + const auto *data = reinterpret_cast(indices_cpu->data()); + return data[i]; + } else if (dtype == DataType::U32) { + const auto *data = reinterpret_cast(indices_cpu->data()); + return static_cast(data[i]); + } else if (dtype == DataType::U64) { + const auto *data = reinterpret_cast(indices_cpu->data()); + uint64_t val = data[i]; + // Check if value can fit in int64_t + if (val > static_cast(std::numeric_limits::max())) { + throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val)); + } + return static_cast(val); + } else { + throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast(dtype))); + } + }; + + if (weight_->device().getType() == Device::Type::CPU) { + // CPU path: memcpy row by row + for (size_t i = 0; i < num_lookups; ++i) { + int64_t idx = read_index(i); + if (idx < 0 || idx >= static_cast(num_embeddings_)) { + throw 
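+ // Out-of-range indices would read past the end of the weight table, so fail fast here instead of copying a garbage row (PyTorch's embedding rejects such indices as well).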
std::out_of_range( + "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); + } + std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); + } + } else { + // Device path: use stream-ordered D2D copies + for (size_t i = 0; i < num_lookups; ++i) { + int64_t idx = read_index(i); + if (idx < 0 || idx >= static_cast(num_embeddings_)) { + throw std::out_of_range( + "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); + } + context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); + } + } + + return out; } std::string Embedding::extra_repr() const { diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index 2dfd3aa21..4d4da708d 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -1,8 +1,6 @@ #include "infinicore/ops/embedding.hpp" + #include "../../utils.hpp" -#include "infinicore/context/context.hpp" -#include -#include namespace infinicore::op { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding); From 0c204dfdcd332eaf441ef2cb09bf85120f226296 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 23 Jan 2026 01:32:36 +0000 Subject: [PATCH 07/25] issue/791 fix add_rmsnorm api and rmsnorm module --- include/infinicore/nn/rmsnorm.hpp | 23 ++++- include/infinicore/ops/add_rms_norm.hpp | 14 ++-- include/infiniop/ops/add_rms_norm.h | 6 +- python/infinicore/__init__.py | 2 +- python/infinicore/ops/add_rms_norm.py | 29 ++----- src/infinicore/nn/rmsnorm.cc | 18 ++++ .../ops/add_rms_norm/add_rms_norm.cc | 24 +++--- .../ops/add_rms_norm/add_rms_norm_infiniop.cc | 71 ++++++++-------- src/infiniop/ops/add_rms_norm/add_rms_norm.h | 6 +- .../ops/add_rms_norm/cpu/add_rms_norm_cpu.cc | 30 +++---- src/infiniop/ops/add_rms_norm/info.h | 6 +- .../nvidia/add_rms_norm_nvidia.cu | 10 +-- src/infiniop/ops/add_rms_norm/operator.cc | 12 +-- test/infinicore/ops/add_rms_norm.py | 84 ++++++++++++------- test/infiniop/add_rms_norm.py | 40 +++++++-- test/infiniop/libinfiniop/op_register.py | 2 + 16 files changed, 225 insertions(+), 152 deletions(-) diff --git a/include/infinicore/nn/rmsnorm.hpp b/include/infinicore/nn/rmsnorm.hpp index 212b2a6e4..5891819eb 100644 --- a/include/infinicore/nn/rmsnorm.hpp +++ b/include/infinicore/nn/rmsnorm.hpp @@ -1,7 +1,7 @@ #pragma once -#include "module.hpp" #include "../ops.hpp" +#include "module.hpp" namespace infinicore::nn { @@ -57,6 +57,21 @@ class RMSNorm : public Module { */ Tensor forward(const Tensor &x) const; + /** + * @brief Forward pass: apply RMSNorm in-place with residual + * + * @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions. + * Will be modified in-place to the normalized output. + * @param residual Residual tensor to add to input before normalization. + * Will be modified in-place to the sum of input and residual. + * + * The normalization is applied over the last dimension. 
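+ * Semantics: residual <- x + residual, then x <- RMSNorm(residual) scaled by weight; if residual is an empty tensor, this degenerates to residual <- x and x <- RMSNorm(x).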
+ * For example: + * Input: [batch, seq_len, hidden_size] -> normalize over hidden_size + * Input: [batch, hidden_size] -> normalize over hidden_size + */ + void forward_inplace(Tensor &x, Tensor &residual) const; + // Module information size_t normalized_shape() const { return normalized_shape_; } double eps() const { return eps_; } @@ -73,9 +88,9 @@ class RMSNorm : public Module { INFINICORE_NN_PARAMETER(weight); private: - size_t normalized_shape_; // Size of the feature dimension - double eps_; // Epsilon for numerical stability - DataType dtype_; // Data type for weight + size_t normalized_shape_; // Size of the feature dimension + double eps_; // Epsilon for numerical stability + DataType dtype_; // Data type for weight }; } // namespace infinicore::nn diff --git a/include/infinicore/ops/add_rms_norm.hpp b/include/infinicore/ops/add_rms_norm.hpp index e8a955a3c..50064e0a4 100644 --- a/include/infinicore/ops/add_rms_norm.hpp +++ b/include/infinicore/ops/add_rms_norm.hpp @@ -5,16 +5,14 @@ #include namespace infinicore::op { -class AddRMSNorm { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float); - static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float); // Fused Add and RMS Normalization // Returns: (normalized_result, add_result) // The add_result can be used as residual for subsequent layers -std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); -void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); +std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f); +void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f); +// Fused Add and RMS Normalization (inplace) +// normalized_result will be stored in input, add_result will be stored in residual +void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f); } // namespace infinicore::op diff --git a/include/infiniop/ops/add_rms_norm.h b/include/infiniop/ops/add_rms_norm.h index 7742c1343..52cd096a6 100644 --- a/include/infiniop/ops/add_rms_norm.h +++ b/include/infiniop/ops/add_rms_norm.h @@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( infiniopHandle_t handle, infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc); + float epsilon); __C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); @@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de void *workspace, size_t workspace_size, void *y, + void *residual_out, const void *a, const void *b, const void *weight, - void *residual_out, void *stream); __C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index c6b01d5aa..52a269ce5 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -43,7 +43,7
@@ uint8, ) from infinicore.ops.add import add -from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ +from infinicore.ops.add_rms_norm import add_rms_norm from infinicore.ops.attention import attention from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul diff --git a/python/infinicore/ops/add_rms_norm.py b/python/infinicore/ops/add_rms_norm.py index 4ad347812..a5de7bd92 100644 --- a/python/infinicore/ops/add_rms_norm.py +++ b/python/infinicore/ops/add_rms_norm.py @@ -1,8 +1,8 @@ +import infinicore.tensor as tensor from infinicore.lib import _infinicore -from infinicore.tensor import Tensor -def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): +def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None): """ Fused Add and RMS Normalization. @@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): The add_result can be used as residual for subsequent layers. """ if out is None: - result = _infinicore.add_rms_norm( - a._underlying, b._underlying, weight._underlying, epsilon - ) - return (Tensor(result[0]), Tensor(result[1])) + out = tensor.empty(a.shape, dtype=a.dtype, device=a.device) + if residual is None: + residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device) - y, residual_out = out _infinicore.add_rms_norm_( - y._underlying, - residual_out._underlying, + out._underlying, + residual._underlying, a._underlying, b._underlying, weight._underlying, epsilon, ) - return (y, residual_out) - -def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5): - """In-place Fused Add and RMS Normalization.""" - _infinicore.add_rms_norm_( - y._underlying, - residual_out._underlying, - a._underlying, - b._underlying, - weight._underlying, - epsilon, - ) + return out, residual diff --git a/src/infinicore/nn/rmsnorm.cc b/src/infinicore/nn/rmsnorm.cc index a83c3a113..107dac44a 100644 --- a/src/infinicore/nn/rmsnorm.cc +++ b/src/infinicore/nn/rmsnorm.cc @@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const { return op::rms_norm(x, weight_, static_cast(eps_)); } +void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { + if (!residual) { + residual = x; + x = op::rms_norm(x, weight_, static_cast(eps_)); + } else { + if (device_.getType() == Device::Type::CPU + || device_.getType() == Device::Type::NVIDIA + || device_.getType() == Device::Type::ILUVATAR + || device_.getType() == Device::Type::METAX + || device_.getType() == Device::Type::MOORE) { + op::add_rms_norm_inplace(x, residual, weight_, static_cast(eps_)); + } else { + op::add_(residual, x, residual); + op::rms_norm_(x, residual, weight_, static_cast(eps_)); + } + } +} + std::string RMSNorm::extra_repr() const { return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc index 650ce87e6..ccba62e21 100644 --- a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc @@ -4,26 +4,30 @@ namespace infinicore::op { -common::OpDispatcher &AddRMSNorm::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm); -void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { +AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor 
&weight, float epsilon) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon); } -std::pair add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { +void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon); +} + +std::pair add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device()); add_rms_norm_(y, residual_out, a, b, weight, epsilon); return std::make_pair(y, residual_out); } -void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { - AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon); +void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { + AddRMSNorm::execute(out, residual, a, b, weight, epsilon); +} + +void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) { + add_rms_norm_(input, residual, input, residual, weight, epsilon); } } // namespace infinicore::op diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc index d6540a039..53d30a2c7 100644 --- a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc @@ -1,50 +1,53 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/add_rms_norm.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::add_rms_norm_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopAddRMSNormDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, residual, a, b, weight; + float epsilon; +}; -void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { +void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, AddRMSNorm, + seed, y->desc(), residual_out->desc(), + a->desc(), b->desc(), weight->desc(), epsilon); + + INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor); - auto desc_opt = cache.get(seed); - infiniopAddRMSNormDescriptor_t desc = nullptr; + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(residual_out), + graph::GraphTensor(a), + graph::GraphTensor(b), + graph::GraphTensor(weight), + epsilon}; - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( - context::getInfiniopHandle(device), &desc, - y->desc(), 
a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return planned; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( - desc, workspace->data(), workspace_size, - y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), + planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - AddRMSNorm::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup); } // namespace infinicore::op::add_rms_norm_impl::infiniop diff --git a/src/infiniop/ops/add_rms_norm/add_rms_norm.h b/src/infiniop/ops/add_rms_norm/add_rms_norm.h index c5d63333d..76451e982 100644 --- a/src/infiniop/ops/add_rms_norm/add_rms_norm.h +++ b/src/infiniop/ops/add_rms_norm/add_rms_norm.h @@ -33,19 +33,19 @@ infiniopHandle_t handle, \ Descriptor **desc_ptr, \ infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t residual_out_desc, \ infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t b_desc, \ infiniopTensorDescriptor_t weight_desc, \ - float epsilon, \ - infiniopTensorDescriptor_t residual_out_desc); \ + float epsilon); \ \ infiniStatus_t calculate( \ void *workspace, size_t workspace_size, \ void *y, \ + void *residual_out, \ const void *a, \ const void *b, \ const void *weight, \ - void *residual_out, \ void *stream) const; \ }; \ } diff --git a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc index 5e7954b71..a3099c5c4 100644 --- a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc +++ b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc @@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } template -infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) { +infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) { const size_t batch_size = info->shape[0]; const size_t nhead = info->ndim() > 2 ? 
info->shape[1] : 1; const size_t dim = info->dim(); @@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T } template -infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) { +infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) { static_assert(std::is_same::value || std::is_same::value, "T must be fp16_t or bf16_t"); @@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { if (_info.atype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight)); } else if (_info.wtype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight)); } else if (_info.wtype == INFINI_DTYPE_BF16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } } else if (_info.atype == INFINI_DTYPE_BF16) { if (_info.wtype == INFINI_DTYPE_BF16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight)); } else if (_info.wtype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight)); } else if (_info.wtype == INFINI_DTYPE_F16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } } else if (_info.atype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out)); + CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight)); } else if (_info.atype == INFINI_DTYPE_F64) { - CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double 
*)a, (const double *)b, (const double *)weight, (double *)residual_out)); + CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/add_rms_norm/info.h b/src/infiniop/ops/add_rms_norm/info.h index abe1b5059..883aed343 100644 --- a/src/infiniop/ops/add_rms_norm/info.h +++ b/src/infiniop/ops/add_rms_norm/info.h @@ -16,9 +16,9 @@ class AddRMSNormInfo { float epsilon; std::vector shape; std::vector y_strides; + std::vector residual_out_strides; std::vector a_strides; std::vector b_strides; - std::vector residual_out_strides; bool has_residual_out; size_t ndim() const { return shape.size(); } @@ -26,11 +26,11 @@ class AddRMSNormInfo { static utils::Result create( infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { + float epsilon) { auto atype = y_desc->dtype(); auto wtype = weight_desc->dtype(); diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu index 03601205f..8fddf5958 100644 --- a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu +++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu @@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); auto info = result.take(); @@ -122,8 +122,8 @@ infiniStatus_t launchKernel( infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc index a856e5447..11c0aef99 100644 --- a/src/infiniop/ops/add_rms_norm/operator.cc +++ b/src/infiniop/ops/add_rms_norm/operator.cc @@ -32,12 +32,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( infiniopHandle_t handle, infiniopAddRMSNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { + float epsilon) { #define CREATE(CASE, NAMESPACE) \ case CASE: \ @@ -45,11 +45,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( handle, \ reinterpret_cast(desc_ptr), \ y_desc, \ + residual_out_desc, \ a_desc, \ b_desc, \ weight_desc, \ - epsilon, \ - residual_out_desc) + epsilon) switch (handle->device) { #ifdef ENABLE_CPU_API @@ -116,16 +116,16 @@ __C infiniStatus_t 
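// Outputs (y, residual_out) now precede the read-only inputs (a, b, weight) in the C API.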
infiniopAddRMSNorm( void *workspace, size_t workspace_size, void *y, + void *residual_out, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) { #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ - ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream) + ->calculate(workspace, workspace_size, y, residual_out, a, b, weight, stream) switch (desc->device_type) { diff --git a/test/infinicore/ops/add_rms_norm.py b/test/infinicore/ops/add_rms_norm.py index 429d9df25..f6bf165a9 100644 --- a/test/infinicore/ops/add_rms_norm.py +++ b/test/infinicore/ops/add_rms_norm.py @@ -30,8 +30,24 @@ ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (2048, 8192, 1), + (2048, 8192, 1), + (2048, 8192, 1), + ), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + (16384, 4096, 1), + ), ] # Tolerance configuration @@ -87,12 +103,14 @@ def parse_test_cases(): y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result) - residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) + residual_out_spec = TensorSpec.from_tensor( + a_shape, a_strides, input_dtype + ) test_cases.append( TestCase( inputs=[a_spec, b_spec, w_spec], kwargs={"epsilon": _EPSILON}, - output_specs=[y_spec, residual_out_spec], # Two outputs + output_specs=None, # Two outputs comparison_target=None, tolerance=tolerance, output_count=2, # Two outputs: normalized_result and add_result @@ -101,19 +119,25 @@ def parse_test_cases(): ) # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w)) - if y_supports_inplace: - residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) - test_cases.append( - TestCase( - inputs=[a_spec, b_spec, w_spec], - kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)}, - output_specs=[y_spec, residual_out_spec], # Two outputs - comparison_target="out", - tolerance=tolerance, - output_count=2, - description=f"AddRMSNorm - INPLACE(out)", - ) - ) + # if y_supports_inplace: + # residual_out_spec = TensorSpec.from_tensor( + # a_shape, a_strides, input_dtype + # ) + # test_cases.append( + # TestCase( + # inputs=[a_spec, b_spec, w_spec], + # kwargs={ + # "epsilon": _EPSILON, + # "out": y_spec, + # "residual": residual_out_spec, + # }, + # output_specs=[y_spec, residual_out_spec], # Two outputs + # comparison_target="out", + # tolerance=tolerance, + # output_count=2, + # description=f"AddRMSNorm - INPLACE(out)", + # ) + # ) return test_cases @@ -127,7 +151,9 @@ def __init__(self): def get_test_cases(self): return parse_test_cases() - def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + def torch_operator( + self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs + ): """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)""" input_dtype = a.dtype @@ -144,21 +170,19 @@ def torch_operator(self, a, b, weight, epsilon=_EPSILON, 
out=None, **kwargs): add_result = sum_tensor.to(input_dtype) if out is not None: - # For in-place operations, we need to handle the output tuple - if isinstance(out, (tuple, list)) and len(out) == 2: - out[0].copy_(normalized_result) - out[1].copy_(add_result) - return tuple(out) - else: - # Single output - just return normalized result for backward compatibility - out.copy_(normalized_result) - return out - + out.copy_(normalized_result) + if residual is not None: + residual.copy_(add_result) + return (normalized_result, add_result) - def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + def infinicore_operator( + self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs + ): """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)""" - return infinicore.add_rms_norm(a, b, weight, epsilon, out=out) + return infinicore.add_rms_norm( + a, b, weight, epsilon, out=out, residual=residual + ) def main(): diff --git a/test/infiniop/add_rms_norm.py b/test/infiniop/add_rms_norm.py index 930314761..e3b4f9b64 100644 --- a/test/infiniop/add_rms_norm.py +++ b/test/infiniop/add_rms_norm.py @@ -32,8 +32,24 @@ ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (2048, 8192, 1), + (2048, 8192, 1), + (2048, 8192, 1), + ), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + (16384, 4096, 1), + ), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None), ] @@ -97,7 +113,9 @@ def test( w = TestTensor(w_shape, None, w_dtype, device) eps = 1e-6 - add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps) + add_rms_norm( + y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps + ) if sync is not None: sync() @@ -109,11 +127,11 @@ def test( handle, ctypes.byref(descriptor), y.descriptor, + residual_out.descriptor, a.descriptor, b.descriptor, w.descriptor, eps, - residual_out.descriptor, ) ) @@ -136,10 +154,10 @@ def lib_add_rms_norm(): workspace.data(), workspace_size.value, y.data(), + residual_out.data(), a.data(), b.data(), w.data(), - residual_out.data(), None, ) ) @@ -147,18 +165,22 @@ def lib_add_rms_norm(): lib_add_rms_norm() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - + # Verify normalized result (y) if DEBUG: debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) - + # Verify add result (residual_out) - should be a + b - expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32) + expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to( + torch.float32 + ) expected_residual = expected_residual.to(a.torch_tensor().dtype) if DEBUG: debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) - assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) + assert torch.allclose( + residual_out.actual_tensor(), expected_residual, 
atol=atol, rtol=rtol + ) # Profiling workflow if PROFILE: diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 618be2b05..7d6cf17e2 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -393,6 +393,7 @@ def add_rms_norm_(lib): infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, c_float, ] @@ -412,6 +413,7 @@ def add_rms_norm_(lib): c_void_p, c_void_p, c_void_p, + c_void_p, ] lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32 From dfafc21f357303d5625d08976c00027b079ce267 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 7 Jan 2026 18:59:25 +0800 Subject: [PATCH 08/25] issue/884 - add_rms_norm on iluvatar, metax and moore --- .../add_rms_norm/metax/add_rms_norm_metax.cuh | 8 + .../metax/add_rms_norm_metax.maca | 167 ++++++++++++++++ .../add_rms_norm/moore/add_rms_norm_moore.h | 8 + .../add_rms_norm/moore/add_rms_norm_moore.mu | 183 ++++++++++++++++++ .../nvidia/add_rms_norm_nvidia.cu | 14 +- src/infiniop/ops/add_rms_norm/operator.cc | 30 ++- 6 files changed, 403 insertions(+), 7 deletions(-) create mode 100644 src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh create mode 100644 src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca create mode 100644 src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h create mode 100644 src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh new file mode 100644 index 000000000..3d6b13b53 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __ADD_RMS_NORM_METAX_CUH__ +#define __ADD_RMS_NORM_METAX_CUH__ + +#include "../add_rms_norm.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca new file mode 100644 index 000000000..8339ec5aa --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca @@ -0,0 +1,167 @@ +#include "../../../devices/metax/metax_common.h" +#include "add_rms_norm_metax.cuh" + +#include "../../../devices/metax/metax_kernel_common.h" +#include + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" + +// Kernel function template for add_rms_norm on Metax platform +template +INFINIOP_METAX_KERNEL add_rmsnormKernel( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_out_batch, + ptrdiff_t stride_residual_out_nhead, + const Tdata *__restrict__ a, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + const Tdata *__restrict__ b, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + const Tweight *__restrict__ w, + size_t nhead, + size_t dim, + float epsilon) { + add_rmsnormBlock( + y, residual_out, + stride_y_batch, stride_y_nhead, + stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + w, nhead, dim, epsilon); +} + +namespace op::add_rms_norm::metax { + +// Internal opaque structure for Metax device handle +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +// Destructor +Descriptor::~Descriptor() { + delete _opaque; +} + +// Create descriptor for add_rms_norm operator +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + 
Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc) { + auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// Launch kernel with different data types +template +infiniStatus_t launchKernel( + uint32_t batch_size, size_t nhead, size_t dim, + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, + void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead, + const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead, + const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead, + const void *w, infiniDtype_t wtype, + float epsilon, + hcStream_t stream) { + +#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ + add_rmsnormKernel<<>>( \ + reinterpret_cast(y), \ + reinterpret_cast(residual_out), \ + stride_y_batch, \ + stride_y_nhead, \ + stride_residual_out_batch, \ + stride_residual_out_nhead, \ + reinterpret_cast(a), \ + stride_a_batch, \ + stride_a_nhead, \ + reinterpret_cast(b), \ + stride_b_batch, \ + stride_b_nhead, \ + reinterpret_cast(w), \ + nhead, \ + dim, \ + epsilon) + + // Handle different data type combinations following Metax pattern + if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, half, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(__hpcc_bfloat16, float, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(half, float, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(half, __hpcc_bfloat16, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(__hpcc_bfloat16, half, float); + } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float, float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +// Main calculation function +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *a, const void *b, const void *weight, + void *residual_out, void *stream_) const { + + // Check workspace size + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + // Extract tensor strides and dimensions + auto stride_a_batch = _info.a_strides[0]; + auto stride_a_nhead = _info.a_strides[1]; + auto stride_b_batch = _info.b_strides[0]; + auto stride_b_nhead = _info.b_strides[1]; + auto stride_y_batch = _info.y_strides[0]; + auto stride_y_nhead = _info.y_strides[1]; + auto stride_residual_out_batch = _info.residual_out_strides[0]; + auto stride_residual_out_nhead = _info.residual_out_strides[1]; + auto dim = _info.dim(); + uint32_t batch_size = static_cast(_info.shape[0]); + size_t nhead = _info.shape.size() > 2 ? 
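+ // 3-D inputs are treated as [batch, nhead, dim]; 2-D inputs as [batch, dim] with nhead = 1.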
_info.shape[1] : 1; + auto stream = reinterpret_cast(stream_); + + // Launch kernel with appropriate block size based on device capability + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add_rms_norm::metax diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h new file mode 100644 index 000000000..9d3f810f2 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h @@ -0,0 +1,8 @@ +#ifndef __ADD_RMS_NORM_MOORE_H__ +#define __ADD_RMS_NORM_MOORE_H__ + +#include "../add_rms_norm.h" + +DESCRIPTOR(moore) + +#endif diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu new file mode 100644 index 000000000..fe7a49765 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu @@ -0,0 +1,183 @@ +#include "../../../devices/moore/moore_common.h" +#include "add_rms_norm_moore.h" + +#include "../../../devices/moore/moore_kernel_common.h" +#include + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" + +// Kernel function template for add_rms_norm on Moore platform +template +INFINIOP_MOORE_KERNEL add_rmsnormKernel( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_out_batch, + ptrdiff_t stride_residual_out_nhead, + const Tdata *__restrict__ a, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + const Tdata *__restrict__ b, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + const Tweight *__restrict__ w, + size_t nhead, + size_t dim, + float epsilon) { + add_rmsnormBlock( + y, residual_out, + stride_y_batch, stride_y_nhead, + stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + w, nhead, dim, epsilon); +} + +namespace op::add_rms_norm::moore { + +// Internal opaque structure for Moore device handle +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +// Destructor +Descriptor::~Descriptor() { + delete _opaque; +} + +// Create descriptor for add_rms_norm operator +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc) { + auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// Launch kernel with different data types +template +infiniStatus_t launchKernel( + uint32_t batch_size, size_t nhead, size_t dim, + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, + void *residual_out, ptrdiff_t 
stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead, + const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead, + const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead, + const void *w, infiniDtype_t wtype, + float epsilon, + musaStream_t musa_stream) { + +#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ + add_rmsnormKernel<<>>( \ + reinterpret_cast(y), \ + reinterpret_cast(residual_out), \ + stride_y_batch, \ + stride_y_nhead, \ + stride_residual_out_batch, \ + stride_residual_out_nhead, \ + reinterpret_cast(a), \ + stride_a_batch, \ + stride_a_nhead, \ + reinterpret_cast(b), \ + stride_b_batch, \ + stride_b_nhead, \ + reinterpret_cast(w), \ + nhead, \ + dim, \ + epsilon) + + // Handle different data type combinations + if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, half, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(half, __mt_bfloat16, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(half, float, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(__mt_bfloat16, half, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(__mt_bfloat16, float, float); + } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float, float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +// Main calculation function +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *a, const void *b, const void *weight, + void *residual_out, void *stream) const { + + // Check workspace size + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + // Extract tensor strides and dimensions + auto stride_a_batch = _info.a_strides[0]; + auto stride_a_nhead = _info.a_strides[1]; + auto stride_b_batch = _info.b_strides[0]; + auto stride_b_nhead = _info.b_strides[1]; + auto stride_y_batch = _info.y_strides[0]; + auto stride_y_nhead = _info.y_strides[1]; + auto stride_residual_out_batch = _info.residual_out_strides[0]; + auto stride_residual_out_nhead = _info.residual_out_strides[1]; + auto dim = _info.dim(); + uint32_t batch_size = static_cast(_info.shape[0]); + size_t nhead = _info.shape.size() > 2 ? 
_info.shape[1] : 1; + auto musa_stream = reinterpret_cast(stream); + + // Launch kernel with appropriate block size based on device capability + if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, musa_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add_rms_norm::moore diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu index 8fddf5958..652f8adf3 100644 --- a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu +++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu @@ -143,7 +143,15 @@ infiniStatus_t Descriptor::calculate( auto cuda_stream = reinterpret_cast(stream); // launch kernel with different block sizes - if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel( batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, @@ -151,8 +159,8 @@ infiniStatus_t Descriptor::calculate( a, stride_a_batch, stride_a_nhead, b, stride_b_batch, stride_b_nhead, weight, _info.wtype, _info.epsilon, cuda_stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { - CHECK_STATUS(launchKernel( + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, residual_out, stride_residual_out_batch, stride_residual_out_nhead, diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc index 11c0aef99..62187cf34 100644 --- a/src/infiniop/ops/add_rms_norm/operator.cc +++ b/src/infiniop/ops/add_rms_norm/operator.cc @@ -17,12 +17,10 @@ // #include "bang/add_rms_norm_bang.h" #endif #ifdef ENABLE_METAX_API -// TODO: Add Metax implementation -// #include "metax/add_rms_norm_metax.cuh" +#include "metax/add_rms_norm_metax.cuh" #endif #ifdef ENABLE_MOORE_API -// TODO: Add Moore implementation -// 
#include "moore/add_rms_norm_moore.h" +#include "moore/add_rms_norm_moore.h" #endif #ifdef ENABLE_KUNLUN_API // TODO: Add Kunlun implementation @@ -61,6 +59,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif @@ -94,6 +98,12 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif @@ -138,6 +148,12 @@ __C infiniStatus_t infiniopAddRMSNorm( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif @@ -173,6 +189,12 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip #ifdef ENABLE_ILUVATAR_API DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia); #endif From 4ddc6647a0ed0c6413ae590a5688471293a89c4d Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 19 Jan 2026 10:55:16 +0000 Subject: [PATCH 09/25] issue/632 - adapt to iluvatar core 20 --- .../devices/nvidia/nvidia_kernel_common.cuh | 1 + .../causal_softmax/nvidia/causal_softmax_nvidia.cu | 14 +++++++++----- .../ops/rms_norm/nvidia/rms_norm_nvidia.cu | 8 +++++--- xmake.lua | 6 ++++++ xmake/iluvatar.lua | 3 +++ 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh b/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh index f11643b42..02cee1ebf 100644 --- a/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh +++ b/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh @@ -14,6 +14,7 @@ // Posible maximum number of threads per block for CUDA architectures // Used for picking correct kernel launch configuration #define CUDA_BLOCK_SIZE_4096 4096 +#define CUDA_BLOCK_SIZE_2048 2048 #define CUDA_BLOCK_SIZE_1024 1024 #define CUDA_BLOCK_SIZE_512 512 diff --git a/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu b/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu index 6dae5af61..6e671df1b 100644 --- a/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu +++ b/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu @@ -76,7 +76,15 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, const void *x, void *stream_) const { cudaStream_t stream = (cudaStream_t)stream_; - if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, + _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + 
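+ // Branch on the device's reported maxThreadsPerBlock; the new 4096/2048 cases are checked before the existing 1024/512 ones so larger-block devices take the wider launch configuration.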
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, + _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel( y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); @@ -84,10 +92,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, CHECK_STATUS(launchKernel( y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { - CHECK_STATUS(launchKernel( - y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, - _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } diff --git a/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu index b083650d4..21cda3695 100644 --- a/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu +++ b/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu @@ -117,12 +117,14 @@ infiniStatus_t Descriptor::calculate( auto cuda_stream = reinterpret_cast(stream); // launch kernel with different block sizes - if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { - CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } diff --git a/xmake.lua b/xmake.lua index d5a4ba7f7..a51435325 100644 --- a/xmake.lua +++ b/xmake.lua @@ -114,6 +114,12 @@ option("iluvatar-gpu") set_description("Whether to compile implementations for Iluvatar GPU") option_end() +option("ivcore-20") + set_default(false) + set_showmenu(true) + set_description("Use ivcore20") +option_end() + if has_config("iluvatar-gpu") then add_defines("ENABLE_ILUVATAR_API") includes("xmake/iluvatar.lua") diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 35ccf2154..57a935f4f 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -44,6 +44,9 @@ target("infiniop-iluvatar") set_warnings("all", "error") 
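    -- Assumed usage: configure with xmake f --iluvatar-gpu=y --ivcore-20=y so that the ivcore-20 option below appends --cuda-gpu-arch=ivcore20 to the compiler flags.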
add_cuflags("-Wno-error=unused-private-field") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) + if has_config("ivcore-20") then + add_cuflags("--cuda-gpu-arch=ivcore20", {force = true}) + end add_culdflags("-fPIC") add_cxflags("-fPIC") From 0611cb1bd9a7ef5659f3d9d6ef1877c7d5e4074e Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 08:06:21 +0000 Subject: [PATCH 10/25] issue/791 - fix add_rmsnorm api on mtx and mth --- .../devices/metax/metax_kernel_common.h | 4 +- .../devices/moore/moore_kernel_common.h | 1 + .../metax/add_rms_norm_metax.maca | 54 +++++++++++++------ .../add_rms_norm/moore/add_rms_norm_moore.mu | 10 ++-- 4 files changed, 48 insertions(+), 21 deletions(-) diff --git a/src/infiniop/devices/metax/metax_kernel_common.h b/src/infiniop/devices/metax/metax_kernel_common.h index f81358d28..d850e9d04 100644 --- a/src/infiniop/devices/metax/metax_kernel_common.h +++ b/src/infiniop/devices/metax/metax_kernel_common.h @@ -8,8 +8,10 @@ // Posible maximum number of threads per block for METAX architectures // Used for picking correct kernel launch configuration -#define METAX_BLOCK_SIZE_1024 1024 #define METAX_BLOCK_SIZE_512 512 +#define METAX_BLOCK_SIZE_1024 1024 +#define METAX_BLOCK_SIZE_2048 2048 +#define METAX_BLOCK_SIZE_4096 4096 #define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess) diff --git a/src/infiniop/devices/moore/moore_kernel_common.h b/src/infiniop/devices/moore/moore_kernel_common.h index e0aea4148..d72cfb197 100644 --- a/src/infiniop/devices/moore/moore_kernel_common.h +++ b/src/infiniop/devices/moore/moore_kernel_common.h @@ -6,6 +6,7 @@ // Posible maximum number of threads per block for MUSA architectures // Used for picking correct kernel launch configuration +#define MOORE_BLOCK_SIZE_4096 4096 #define MOORE_BLOCK_SIZE_2048 2048 #define MOORE_BLOCK_SIZE_1024 1024 #define MOORE_BLOCK_SIZE_512 512 diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca index 8339ec5aa..04355e927 100644 --- a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca +++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca @@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); auto info = result.take(); @@ -104,16 +104,16 @@ infiniStatus_t launchKernel( // Handle different data type combinations following Metax pattern if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { LAUNCH_KERNEL(half, half, float); - } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { - LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float); - } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { - LAUNCH_KERNEL(__hpcc_bfloat16, float, float); - } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { - LAUNCH_KERNEL(half, float, float); } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { LAUNCH_KERNEL(half, __hpcc_bfloat16, float); + } else if (atype == INFINI_DTYPE_F16 
&& wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(half, float, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float); } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { LAUNCH_KERNEL(__hpcc_bfloat16, half, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(__hpcc_bfloat16, float, float); } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { LAUNCH_KERNEL(float, float, float); } else { @@ -128,8 +128,8 @@ infiniStatus_t launchKernel( // Main calculation function infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream_) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { // Check workspace size if (workspace_size < _workspace_size) { @@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate( auto dim = _info.dim(); uint32_t batch_size = static_cast(_info.shape[0]); size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; - auto stream = reinterpret_cast(stream_); + auto stream_ = reinterpret_cast(stream); - // Launch kernel with appropriate block size based on device capability - if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { + // Launch kernel with different block sizes + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel( batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, residual_out, stride_residual_out_batch, stride_residual_out_nhead, a, stride_a_batch, stride_a_nhead, b, stride_b_batch, stride_b_nhead, - weight, _info.wtype, _info.epsilon, stream)); + weight, _info.wtype, _info.epsilon, stream_)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu index fe7a49765..90c027ead 100644 --- a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu +++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu @@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, 
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); auto info = result.take(); @@ -128,8 +128,8 @@ infiniStatus_t launchKernel( // Main calculation function infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { // Check workspace size if (workspace_size < _workspace_size) { From 81e5fe948a80902f80bb56e99bed248a35248063 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Mon, 19 Jan 2026 10:38:32 +0000 Subject: [PATCH 11/25] issue/810 support more ops as graph op --- include/infinicore/graph/graph.hpp | 23 +- include/infinicore/ops/add.hpp | 15 +- include/infinicore/ops/causal_softmax.hpp | 14 +- .../infinicore/ops/distributed/allreduce.hpp | 24 ++ include/infinicore/ops/gemm.hpp | 6 +- include/infinicore/ops/mul.hpp | 14 +- include/infinicore/ops/paged_attention.hpp | 18 +- include/infinicore/ops/paged_caching.hpp | 10 +- include/infinicore/ops/rearrange.hpp | 14 +- include/infinicore/ops/rms_norm.hpp | 14 +- include/infinicore/ops/rope.hpp | 27 +- include/infinicore/ops/swiglu.hpp | 15 +- src/infinicore/graph/graph.cc | 4 +- src/infinicore/nn/linear.cc | 19 +- src/infinicore/ops/add/add.cc | 18 +- src/infinicore/ops/add/add_infiniop.cc | 70 ++-- .../ops/causal_softmax/causal_softmax.cc | 30 +- .../causal_softmax/causal_softmax_infiniop.cc | 67 ++-- src/infinicore/ops/distributed/allreduce.cc | 50 +++ src/infinicore/ops/gemm/gemm.cc | 8 +- src/infinicore/ops/gemm/gemm_infiniop.cc | 2 +- src/infinicore/ops/infiniop_impl.hpp | 57 ++- src/infinicore/ops/mul/mul.cc | 19 +- src/infinicore/ops/mul/mul_infiniop.cc | 69 ++-- .../ops/paged_attention/paged_attention.cc | 30 +- .../paged_attention_infiniop.cc | 106 +++--- .../ops/paged_caching/paged_caching.cc | 17 +- .../paged_caching/paged_caching_infiniop.cc | 91 ++--- src/infinicore/ops/rearrange/rearrange.cc | 24 +- .../ops/rearrange/rearrange_infiniop.cc | 59 ++- src/infinicore/ops/rms_norm/rms_norm.cc | 20 +- .../ops/rms_norm/rms_norm_infiniop.cc | 79 ++-- src/infinicore/ops/rope/rope.cc | 49 +-- src/infinicore/ops/rope/rope_infiniop.cc | 110 +++--- src/infinicore/ops/swiglu/swiglu.cc | 31 +- src/infinicore/ops/swiglu/swiglu_infiniop.cc | 89 ++--- src/infinicore/utils.hpp | 1 + test/infinicore/graph/attention.py | 356 ++++++++++++++++++ test/infinicore/graph/graph.py | 85 ----- 39 files changed, 1074 insertions(+), 680 deletions(-) create mode 100644 include/infinicore/ops/distributed/allreduce.hpp create mode 100644 src/infinicore/ops/distributed/allreduce.cc create mode 100644 test/infinicore/graph/attention.py delete mode 100644 test/infinicore/graph/graph.py diff --git a/include/infinicore/graph/graph.hpp b/include/infinicore/graph/graph.hpp index d997e0224..6f50cf730 100644 --- a/include/infinicore/graph/graph.hpp +++ b/include/infinicore/graph/graph.hpp @@ -15,10 +15,15 @@ class GraphTensor : public Tensor { }; class GraphOperator { +public: + virtual void run() const = 0; + virtual ~GraphOperator() = default; +}; +class 
DispatchableGraphOperator : public GraphOperator { public: - void run() const; - ~GraphOperator(); + void run() const override; + ~DispatchableGraphOperator() override; protected: using run_schema = void (*)(void *); @@ -49,7 +54,7 @@ class Graph { } // namespace infinicore::graph #define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \ - class __OP_NAME__ : public graph::GraphOperator { \ + class __OP_NAME__ : public graph::DispatchableGraphOperator { \ public: \ using schema = void (*)(__VA_ARGS__); \ using plan_schema = void *(*)(__VA_ARGS__); \ @@ -79,12 +84,12 @@ class Graph { runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \ deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__); -#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \ - auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \ - if (context::isGraphRecording()) { \ - context::addGraphOperator(op); \ - } else { \ - op->run(); \ +#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \ + auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \ + if (context::isGraphRecording()) { \ + context::addGraphOperator(___op); \ + } else { \ + ___op->run(); \ } #define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \ diff --git a/include/infinicore/ops/add.hpp b/include/infinicore/ops/add.hpp index 1dd5df0ff..528cca18a 100644 --- a/include/infinicore/ops/add.hpp +++ b/include/infinicore/ops/add.hpp @@ -1,17 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Add { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor c, Tensor a, Tensor b); - static common::OpDispatcher &dispatcher(); -}; -Tensor add(Tensor a, Tensor b); -void add_(Tensor c, Tensor a, Tensor b); -Tensor operator+(Tensor a, Tensor b); +INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &); + +Tensor add(const Tensor &a, const Tensor &b); +void add_(Tensor c, const Tensor &a, const Tensor &b); + } // namespace infinicore::op diff --git a/include/infinicore/ops/causal_softmax.hpp b/include/infinicore/ops/causal_softmax.hpp index ae40d521c..2646852af 100644 --- a/include/infinicore/ops/causal_softmax.hpp +++ b/include/infinicore/ops/causal_softmax.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class CausalSoftmax { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor output, Tensor input); - static common::OpDispatcher &dispatcher(); -}; -Tensor causal_softmax(Tensor input); -void causal_softmax_(Tensor output, Tensor input); +INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &); + +Tensor causal_softmax(const Tensor &input); +void causal_softmax_(Tensor output, const Tensor &input); + } // namespace infinicore::op diff --git a/include/infinicore/ops/distributed/allreduce.hpp b/include/infinicore/ops/distributed/allreduce.hpp new file mode 100644 index 000000000..39f74243a --- /dev/null +++ b/include/infinicore/ops/distributed/allreduce.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "../../device.hpp" +#include "../../graph/graph.hpp" +#include "../common/op.hpp" + +#include + +namespace infinicore::op::distributed { +class AllReduce : public graph::GraphOperator { +public: + AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); + ~AllReduce(); + void run() const override; + static void 
execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); + +private: + void *planned_meta_; +}; + +Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); +void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); + +} // namespace infinicore::op::distributed diff --git a/include/infinicore/ops/gemm.hpp b/include/infinicore/ops/gemm.hpp index 481d47cf6..4f76cee26 100644 --- a/include/infinicore/ops/gemm.hpp +++ b/include/infinicore/ops/gemm.hpp @@ -6,9 +6,9 @@ namespace infinicore::op { -INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float); +INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float); -Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f); -void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta); +Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f); +void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta); } // namespace infinicore::op diff --git a/include/infinicore/ops/mul.hpp b/include/infinicore/ops/mul.hpp index 83416bbd9..2eb480ddb 100644 --- a/include/infinicore/ops/mul.hpp +++ b/include/infinicore/ops/mul.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Mul { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor c, Tensor a, Tensor b); - static common::OpDispatcher &dispatcher(); -}; -Tensor mul(Tensor a, Tensor b); -void mul_(Tensor c, Tensor a, Tensor b); +INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &); + +Tensor mul(const Tensor &a, const Tensor &b); +void mul_(Tensor c, const Tensor &a, const Tensor &b); + } // namespace infinicore::op diff --git a/include/infinicore/ops/paged_attention.hpp b/include/infinicore/ops/paged_attention.hpp index 54d61fa89..8c906c95e 100644 --- a/include/infinicore/ops/paged_attention.hpp +++ b/include/infinicore/ops/paged_attention.hpp @@ -1,18 +1,20 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" #include namespace infinicore::op { -class PagedAttention { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional, float); - static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional alibi_slopes, float); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional, float); + +Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale); + +void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale); -Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional alibi_slopes, float scale); -void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional alibi_slopes, float scale); } // namespace infinicore::op diff --git a/include/infinicore/ops/paged_caching.hpp 
b/include/infinicore/ops/paged_caching.hpp index e357cda38..403b4b738 100644 --- a/include/infinicore/ops/paged_caching.hpp +++ b/include/infinicore/ops/paged_caching.hpp @@ -1,17 +1,13 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class PagedCaching { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); - static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &); -void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping); +void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping); } // namespace infinicore::op diff --git a/include/infinicore/ops/rearrange.hpp b/include/infinicore/ops/rearrange.hpp index 3576365e0..5db983ef8 100644 --- a/include/infinicore/ops/rearrange.hpp +++ b/include/infinicore/ops/rearrange.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Rearrange { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor y, Tensor x); - static common::OpDispatcher &dispatcher(); -}; -Tensor rearrange(Tensor x); -void rearrange_(Tensor y, Tensor x); +INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &); + +Tensor rearrange(const Tensor &x); +void rearrange_(Tensor y, const Tensor &x); + } // namespace infinicore::op diff --git a/include/infinicore/ops/rms_norm.hpp b/include/infinicore/ops/rms_norm.hpp index 1212c446e..c7b2b2d72 100644 --- a/include/infinicore/ops/rms_norm.hpp +++ b/include/infinicore/ops/rms_norm.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class RMSNorm { -public: - using schema = void (*)(Tensor, Tensor, Tensor, float); - static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f); - static common::OpDispatcher &dispatcher(); -}; -Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f); -void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f); +INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float); + +Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f); +void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f); + } // namespace infinicore::op diff --git a/include/infinicore/ops/rope.hpp b/include/infinicore/ops/rope.hpp index a5f7792b9..8fd630ce1 100644 --- a/include/infinicore/ops/rope.hpp +++ b/include/infinicore/ops/rope.hpp @@ -1,21 +1,28 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "../nn/rope.hpp" #include "../tensor.hpp" #include "common/op.hpp" namespace infinicore::op { -class RoPE { -public: - using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo); - static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo); - static common::OpDispatcher &dispatcher(); -}; -// Internal function -void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo); 
+INFINICORE_GRAPH_OP_CLASS(RoPE, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo); + +// Internal +void rope_(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo); + +// Public API +Tensor rope(const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo); -// Public API that uses infinicore::nn::RoPE::Algo -Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo); } // namespace infinicore::op diff --git a/include/infinicore/ops/swiglu.hpp b/include/infinicore/ops/swiglu.hpp index 47a3e0f44..7aa77e632 100644 --- a/include/infinicore/ops/swiglu.hpp +++ b/include/infinicore/ops/swiglu.hpp @@ -1,16 +1,15 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" +#include "../tensor.hpp" #include "common/op.hpp" namespace infinicore::op { -class SwiGLU { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor c, Tensor a, Tensor b); - static common::OpDispatcher &dispatcher(); -}; -Tensor swiglu(Tensor a, Tensor b); -void swiglu_(Tensor c, Tensor a, Tensor b); +INFINICORE_GRAPH_OP_CLASS(SwiGLU, Tensor, const Tensor &, const Tensor &); + +Tensor swiglu(const Tensor &a, const Tensor &b); +void swiglu_(Tensor c, const Tensor &a, const Tensor &b); + } // namespace infinicore::op diff --git a/src/infinicore/graph/graph.cc b/src/infinicore/graph/graph.cc index 8218b1b48..8a06e5f40 100644 --- a/src/infinicore/graph/graph.cc +++ b/src/infinicore/graph/graph.cc @@ -17,11 +17,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) { * GraphOperator * ========================= */ -void GraphOperator::run() const { +void DispatchableGraphOperator::run() const { runner_(planned_meta_); } -GraphOperator::~GraphOperator() { +DispatchableGraphOperator::~DispatchableGraphOperator() { if (deleter_) { deleter_(&planned_meta_); } diff --git a/src/infinicore/nn/linear.cc b/src/infinicore/nn/linear.cc index bb4fc29b1..0be993699 100644 --- a/src/infinicore/nn/linear.cc +++ b/src/infinicore/nn/linear.cc @@ -1,6 +1,7 @@ #include "infinicore/nn/linear.hpp" #include "../utils.hpp" #include "infinicore/ops.hpp" +#include "infinicore/ops/distributed/allreduce.hpp" #include "infinicore/ops/linear.hpp" #include #include @@ -102,9 +103,6 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur } else { bias_ = Parameter(); // Default constructed empty parameter } - - // SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", - // in_features, out_features, bias, static_cast(dtype_)); } Tensor ColumnParallelLinear::forward(Tensor &input) const { @@ -138,26 +136,13 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo } else { bias_ = Parameter(); // Default constructed empty parameter } - - // SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", - // in_features, out_features, bias, static_cast(dtype_)); } Tensor RowParallelLinear::forward(Tensor &input) const { auto output = BaseLinear::forward(input); if ((tp_size_ > 1) && (communicator_ != nullptr)) { - - Size count = output->numel(); - DataType type = output->dtype(); - - infinirtStream_t stream = infinicore::context::getStream(); - - 
INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count, static_cast(static_cast(type)), - INFINICCL_SUM, communicator_, stream)); - INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream)); - - // RUN_INFINI(infinirtStreamSynchronize(stream)); + op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_); } return output; } diff --git a/src/infinicore/ops/add/add.cc b/src/infinicore/ops/add/add.cc index ef776d632..815a2de27 100644 --- a/src/infinicore/ops/add/add.cc +++ b/src/infinicore/ops/add/add.cc @@ -3,24 +3,24 @@ namespace infinicore::op { -common::OpDispatcher &Add::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Add); -void Add::execute(Tensor c, Tensor a, Tensor b) { +Add::Add(Tensor c, const Tensor &a, const Tensor &b) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); - infinicore::context::setDevice(c->device()); - dispatcher().lookup(c->device().getType())(c, a, b); + INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b); } -Tensor add(Tensor a, Tensor b) { +void Add::execute(Tensor c, const Tensor &a, const Tensor &b) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Add, c, a, b); +} + +Tensor add(const Tensor &a, const Tensor &b) { auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); add_(c, a, b); return c; } -void add_(Tensor c, Tensor a, Tensor b) { +void add_(Tensor c, const Tensor &a, const Tensor &b) { Add::execute(c, a, b); } diff --git a/src/infinicore/ops/add/add_infiniop.cc b/src/infinicore/ops/add/add_infiniop.cc index 29c36770c..bb377d667 100644 --- a/src/infinicore/ops/add/add_infiniop.cc +++ b/src/infinicore/ops/add/add_infiniop.cc @@ -1,50 +1,52 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/add.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::add_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopAddDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAddDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Add, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, c, a, b; +}; -void calculate(Tensor c, Tensor a, Tensor b) { +void *plan(Tensor c, const Tensor &a, const Tensor &b) { size_t seed = hash_combine(c, b, a); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Add, + seed, + c->desc(), a->desc(), b->desc()); - auto desc_opt = cache.get(seed); - infiniopAddDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Add, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( - context::getInfiniopHandle(device), &desc, - c->desc(), a->desc(), b->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(c), + graph::GraphTensor(a), + graph::GraphTensor(b)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetAddWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAdd( - desc, workspace->data(), workspace_size, - c->data(), a->data(), b->data(), 
context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->c->data(), + planned->a->data(), + planned->b->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Add::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Add, &plan, &run, &cleanup); } // namespace infinicore::op::add_impl::infiniop diff --git a/src/infinicore/ops/causal_softmax/causal_softmax.cc b/src/infinicore/ops/causal_softmax/causal_softmax.cc index 3194dff94..328ff390e 100644 --- a/src/infinicore/ops/causal_softmax/causal_softmax.cc +++ b/src/infinicore/ops/causal_softmax/causal_softmax.cc @@ -1,37 +1,27 @@ #include "infinicore/ops/causal_softmax.hpp" - #include "../../utils.hpp" -#include - namespace infinicore::op { -common::OpDispatcher &CausalSoftmax::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(CausalSoftmax); -void CausalSoftmax::execute(Tensor output, Tensor input) { +CausalSoftmax::CausalSoftmax(Tensor output, const Tensor &input) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); - infinicore::context::setDevice(output->device()); - auto device_type = output->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast(device_type))); - } + INFINICORE_GRAPH_OP_DISPATCH(output->device().getType(), output, input); +} - func(output, input); +void CausalSoftmax::execute(Tensor output, const Tensor &input) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(CausalSoftmax, output, input); } -Tensor causal_softmax(Tensor input) { - Shape shape = input->shape(); - auto output = Tensor::empty(shape, input->dtype(), input->device()); +Tensor causal_softmax(const Tensor &input) { + auto output = Tensor::empty(input->shape(), input->dtype(), input->device()); causal_softmax_(output, input); return output; } -void causal_softmax_(Tensor output, Tensor input) { +void causal_softmax_(Tensor output, const Tensor &input) { CausalSoftmax::execute(output, input); } + } // namespace infinicore::op diff --git a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc index 082d0e642..e0a0595fb 100644 --- a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc +++ b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc @@ -1,50 +1,49 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/causal_softmax.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::causal_softmax_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopCausalSoftmaxDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, CausalSoftmax, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, output, input; +}; -void calculate(Tensor output, Tensor input) { +void *plan(Tensor output, const Tensor &input) { size_t seed = hash_combine(output, input); - auto device = context::getDevice(); - auto &cache = 
caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, CausalSoftmax, + seed, output->desc(), input->desc()); - auto desc_opt = cache.get(seed); - infiniopCausalSoftmaxDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, CausalSoftmax, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( - context::getInfiniopHandle(device), &desc, - output->desc(), input->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(output), + graph::GraphTensor(input)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopCausalSoftmax( - desc, workspace->data(), workspace_size, - output->data(), input->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->output->data(), + planned->input->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - CausalSoftmax::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(CausalSoftmax, &plan, &run, &cleanup); } // namespace infinicore::op::causal_softmax_impl::infiniop diff --git a/src/infinicore/ops/distributed/allreduce.cc b/src/infinicore/ops/distributed/allreduce.cc new file mode 100644 index 000000000..ddfc238c9 --- /dev/null +++ b/src/infinicore/ops/distributed/allreduce.cc @@ -0,0 +1,50 @@ +#include "infinicore/ops/distributed/allreduce.hpp" +#include "../../utils.hpp" + +namespace infinicore::op::distributed { + +struct PlannedMeta { + graph::GraphTensor output, input; + infinicclReduceOp_t op; + infinicclComm_t communicator; +}; + +AllReduce::AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); + INFINICORE_ASSERT(output->is_contiguous() && input->is_contiguous()); + INFINICORE_ASSERT(output->numel() == input->numel()); + planned_meta_ = new PlannedMeta{graph::GraphTensor(output), graph::GraphTensor(input), op, communicator}; +} +AllReduce::~AllReduce() { + if (planned_meta_) { + PlannedMeta *meta = reinterpret_cast(planned_meta_); + delete meta; + } +} + +void AllReduce::run() const { + PlannedMeta *meta = reinterpret_cast(planned_meta_); + + INFINICORE_CHECK_ERROR(infinicclAllReduce(meta->input->data(), + meta->output->data(), + meta->input->numel(), + static_cast(static_cast(meta->input->dtype())), + meta->op, + meta->communicator, + infinicore::context::getStream())); +} + +void AllReduce::execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(AllReduce, output, input, op, communicator); +} + +Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + auto output = Tensor::empty(input->shape(), input->dtype(), input->device()); + allreduce_(output, input, op, communicator); + return output; +} + +void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + 
AllReduce::execute(output, input, op, communicator); +} +} // namespace infinicore::op::distributed diff --git a/src/infinicore/ops/gemm/gemm.cc b/src/infinicore/ops/gemm/gemm.cc index e2b3924f7..765bc869f 100644 --- a/src/infinicore/ops/gemm/gemm.cc +++ b/src/infinicore/ops/gemm/gemm.cc @@ -5,16 +5,16 @@ namespace infinicore::op { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm); -Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +Gemm::Gemm(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta); } -void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +void Gemm::execute(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta); } -Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { +Tensor gemm(const Tensor &a, const Tensor &b, float alpha, float beta) { Shape shape = a->shape(); Size size = a->ndim(); shape[size - 1] = b->size(size - 1); @@ -23,7 +23,7 @@ Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { return c; } -void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { Gemm::execute(c, a, b, alpha, beta); } diff --git a/src/infinicore/ops/gemm/gemm_infiniop.cc b/src/infinicore/ops/gemm/gemm_infiniop.cc index 670fdbc2a..33a7271c0 100644 --- a/src/infinicore/ops/gemm/gemm_infiniop.cc +++ b/src/infinicore/ops/gemm/gemm_infiniop.cc @@ -11,7 +11,7 @@ struct PlannedMeta { float alpha, beta; }; -void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +void *plan(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { size_t seed = hash_combine(c, a, b); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( diff --git a/src/infinicore/ops/infiniop_impl.hpp b/src/infinicore/ops/infiniop_impl.hpp index 2bf38c8c6..67c09554c 100644 --- a/src/infinicore/ops/infiniop_impl.hpp +++ b/src/infinicore/ops/infiniop_impl.hpp @@ -5,23 +5,46 @@ #include "infinicore/ops/common/cache.hpp" #include -#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \ - struct __DESC_TYPE__ { \ - infiniop##__OP_NAME__##Descriptor_t desc; \ - Descriptor(infiniop##__OP_NAME__##Descriptor_t desc) : desc(desc) {} \ - ~Descriptor() { \ - if (desc != nullptr) { \ - infiniopDestroy##__OP_NAME__##Descriptor(desc); \ - desc = nullptr; \ - } \ - } \ - }; \ - \ - thread_local common::OpCache> \ - caches( \ - __SIZE__, \ - [](std::shared_ptr<__DESC_TYPE__> &desc) { \ - desc = nullptr; \ +#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \ + struct __DESC_TYPE__ { \ + infiniop##__OP_NAME__##Descriptor_t desc = nullptr; \ + \ + explicit __DESC_TYPE__(infiniop##__OP_NAME__##Descriptor_t d) \ + : desc(d) {} \ + \ + /* non-copyable */ \ + __DESC_TYPE__(const __DESC_TYPE__ &) = delete; \ + __DESC_TYPE__ &operator=(const __DESC_TYPE__ &) = delete; \ + \ + /* movable */ \ + __DESC_TYPE__(__DESC_TYPE__ &&other) noexcept \ + : desc(other.desc) { \ + other.desc = nullptr; \ + } \ + \ + __DESC_TYPE__ &operator=(__DESC_TYPE__ &&other) noexcept { \ + if (this != &other) { \ + if (desc != nullptr) { \ + infiniopDestroy##__OP_NAME__##Descriptor(desc); \ + } \ + desc = other.desc; \ + other.desc = nullptr; \ + } \ + return *this; \ + } \ + \ + ~__DESC_TYPE__() { \ + if (desc != nullptr) { \ + 
infiniopDestroy##__OP_NAME__##Descriptor(desc); \ + } \ + } \ + }; \ + \ + thread_local common::OpCache> \ + caches( \ + __SIZE__, \ + [](std::shared_ptr<__DESC_TYPE__> &desc) { \ + desc = nullptr; \ }); #define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(__DESC_TYPE__, __DESC_NAME__, __INFINIOP_NAME__, __HASH_KEY__, ...) \ diff --git a/src/infinicore/ops/mul/mul.cc b/src/infinicore/ops/mul/mul.cc index 736e44269..6923fed9c 100644 --- a/src/infinicore/ops/mul/mul.cc +++ b/src/infinicore/ops/mul/mul.cc @@ -1,27 +1,26 @@ #include "infinicore/ops/mul.hpp" - #include "../../utils.hpp" namespace infinicore::op { -common::OpDispatcher &Mul::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Mul); -void Mul::execute(Tensor c, Tensor a, Tensor b) { +Mul::Mul(Tensor c, const Tensor &a, const Tensor &b) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); - infinicore::context::setDevice(c->device()); - dispatcher().lookup(c->device().getType())(c, a, b); + INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b); +} + +void Mul::execute(Tensor c, const Tensor &a, const Tensor &b) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Mul, c, a, b); } -Tensor mul(Tensor a, Tensor b) { +Tensor mul(const Tensor &a, const Tensor &b) { auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); mul_(c, a, b); return c; } -void mul_(Tensor c, Tensor a, Tensor b) { +void mul_(Tensor c, const Tensor &a, const Tensor &b) { Mul::execute(c, a, b); } diff --git a/src/infinicore/ops/mul/mul_infiniop.cc b/src/infinicore/ops/mul/mul_infiniop.cc index 885a5f842..39a7bd87d 100644 --- a/src/infinicore/ops/mul/mul_infiniop.cc +++ b/src/infinicore/ops/mul/mul_infiniop.cc @@ -1,50 +1,51 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/mul.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::mul_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopMulDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Mul, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, c, a, b; +}; -void calculate(Tensor c, Tensor a, Tensor b) { +void *plan(Tensor c, const Tensor &a, const Tensor &b) { size_t seed = hash_combine(c, b, a); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Mul, + seed, c->desc(), a->desc(), b->desc()); - auto desc_opt = cache.get(seed); - infiniopMulDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Mul, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor( - context::getInfiniopHandle(device), &desc, - c->desc(), a->desc(), b->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(c), + graph::GraphTensor(a), + graph::GraphTensor(b)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopMul( - desc, workspace->data(), workspace_size, - c->data(), a->data(), 
b->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->c->data(), + planned->a->data(), + planned->b->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Mul::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Mul, &plan, &run, &cleanup); } // namespace infinicore::op::mul_impl::infiniop diff --git a/src/infinicore/ops/paged_attention/paged_attention.cc b/src/infinicore/ops/paged_attention/paged_attention.cc index 171614087..60de2ae66 100644 --- a/src/infinicore/ops/paged_attention/paged_attention.cc +++ b/src/infinicore/ops/paged_attention/paged_attention.cc @@ -1,27 +1,37 @@ #include "infinicore/ops/paged_attention.hpp" - #include "../../utils.hpp" namespace infinicore::op { -common::OpDispatcher &PagedAttention::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PagedAttention); -void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { +PagedAttention::PagedAttention(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens); - infinicore::context::setDevice(out->device()); - dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); +} + +void PagedAttention::execute(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + PagedAttention, + out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); } -Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { +Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); return out; } -void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { +void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); } diff --git a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc index 3d367c5bb..733733a6b 100644 --- a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc +++ b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc @@ -1,54 +1,68 @@ -#include "../../utils.hpp" -#include 
"infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/paged_attention.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::paged_attention_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopPagedAttentionDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { - size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = cache.get(seed); - infiniopPagedAttentionDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor( - context::getInfiniopHandle(device), &desc, - out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(), - alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, - scale)); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopPagedAttention( - desc, workspace->data(), workspace_size, - out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(), - alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, - context::getStream())); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PagedAttention, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, q, k_cache, v_cache, block_tables, cache_lens; + std::optional alibi_slopes; + float scale; +}; + +void *plan(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &cache_lens, + std::optional alibi_slopes, float scale) { + size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, PagedAttention, + seed, + out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), + block_tables->desc(), cache_lens->desc(), + alibi_slopes ? alibi_slopes.value()->desc() : nullptr, + scale); + + INFINIOP_WORKSPACE_TENSOR(workspace, PagedAttention, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(block_tables), + graph::GraphTensor(cache_lens), + alibi_slopes ? std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopPagedAttention( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->out->data(), + p->q->data(), + p->k_cache->data(), + p->v_cache->data(), + p->block_tables->data(), + p->cache_lens->data(), + p->alibi_slopes.has_value() ? 
p->alibi_slopes.value()->data() : nullptr, + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - PagedAttention::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PagedAttention, &plan, &run, &cleanup); } // namespace infinicore::op::paged_attention_impl::infiniop diff --git a/src/infinicore/ops/paged_caching/paged_caching.cc b/src/infinicore/ops/paged_caching/paged_caching.cc index cc14bf236..afc8bf0c6 100644 --- a/src/infinicore/ops/paged_caching/paged_caching.cc +++ b/src/infinicore/ops/paged_caching/paged_caching.cc @@ -1,21 +1,20 @@ #include "infinicore/ops/paged_caching.hpp" - #include "../../utils.hpp" namespace infinicore::op { -common::OpDispatcher &PagedCaching::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PagedCaching); -void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { +PagedCaching::PagedCaching(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping); - infinicore::context::setDevice(k_cache->device()); - dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping); + INFINICORE_GRAPH_OP_DISPATCH(k->device().getType(), k_cache, v_cache, k, v, slot_mapping); +} + +void PagedCaching::execute(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(PagedCaching, k_cache, v_cache, k, v, slot_mapping); } -void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { +void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping); } diff --git a/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc b/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc index 7dcaf47a0..5e8be049a 100644 --- a/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc +++ b/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc @@ -1,50 +1,57 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/paged_caching.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::paged_caching_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopPagedCachingDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyPagedCachingDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { - size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = cache.get(seed); - infiniopPagedCachingDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor( - context::getInfiniopHandle(device), &desc, - k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetPagedCachingWorkspaceSize(desc, 
&workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopPagedCaching( - desc, workspace->data(), workspace_size, - k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream())); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PagedCaching, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + + graph::GraphTensor workspace, k_cache, v_cache, k, v, slot_mapping; +}; + +void *plan(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { + size_t key = hash_combine(k_cache, v_cache, k, v, slot_mapping); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, PagedCaching, + key, k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, PagedCaching, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(slot_mapping)}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopPagedCaching( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->k_cache->data(), + p->v_cache->data(), + p->k->data(), + p->v->data(), + p->slot_mapping->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - PagedCaching::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PagedCaching, &plan, &run, &cleanup); } // namespace infinicore::op::paged_caching_impl::infiniop diff --git a/src/infinicore/ops/rearrange/rearrange.cc b/src/infinicore/ops/rearrange/rearrange.cc index c70a9e930..191d0871f 100644 --- a/src/infinicore/ops/rearrange/rearrange.cc +++ b/src/infinicore/ops/rearrange/rearrange.cc @@ -3,24 +3,30 @@ namespace infinicore::op { -common::OpDispatcher &Rearrange::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Rearrange); -void Rearrange::execute(Tensor y, Tensor x) { +Rearrange::Rearrange(Tensor y, const Tensor &x) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, x); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x); } -Tensor rearrange(Tensor x) { +void Rearrange::execute(Tensor y, const Tensor &x) { + auto op = std::make_shared(y, x); + if (context::isGraphRecording()) { + context::addGraphOperator(op); + } else { + op->run(); + } +} + +Tensor rearrange(const Tensor &x) { auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); rearrange_(y, x); return y; } -void rearrange_(Tensor y, Tensor x) { +void rearrange_(Tensor y, const Tensor &x) { Rearrange::execute(y, x); } + } // namespace infinicore::op diff --git a/src/infinicore/ops/rearrange/rearrange_infiniop.cc b/src/infinicore/ops/rearrange/rearrange_infiniop.cc index 71b43f027..f30b09e79 100644 --- a/src/infinicore/ops/rearrange/rearrange_infiniop.cc +++ b/src/infinicore/ops/rearrange/rearrange_infiniop.cc @@ -1,47 +1,46 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rearrange.hpp" -#include + +#include "../infiniop_impl.hpp" 
namespace infinicore::op::rearrange_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopRearrangeDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRearrangeDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Rearrange, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor y, x; +}; -void calculate(Tensor y, Tensor x) { +void *plan(Tensor y, const Tensor &x) { size_t seed = hash_combine(y, x); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Rearrange, + seed, y->desc(), + x->desc()); - auto desc_opt = cache.get(seed); - infiniopRearrangeDescriptor_t desc = nullptr; + return new PlannedMeta{ + descriptor, + graph::GraphTensor(y), + graph::GraphTensor(x)}; +} - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(device), &desc, y->desc(), x->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR( infiniopRearrange( - desc, - y->data(), - x->data(), + planned->descriptor->desc, + planned->y->data(), + planned->x->data(), context::getStream())); } -static bool registered = []() { - Rearrange::dispatcher().registerAll(&calculate, false); - return true; -}(); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Rearrange, &plan, &run, &cleanup); } // namespace infinicore::op::rearrange_impl::infiniop diff --git a/src/infinicore/ops/rms_norm/rms_norm.cc b/src/infinicore/ops/rms_norm/rms_norm.cc index 20e598056..0f8f2e57b 100644 --- a/src/infinicore/ops/rms_norm/rms_norm.cc +++ b/src/infinicore/ops/rms_norm/rms_norm.cc @@ -1,27 +1,25 @@ #include "infinicore/ops/rms_norm.hpp" - #include "../../utils.hpp" namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RMSNorm); -common::OpDispatcher &RMSNorm::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) { +RMSNorm::RMSNorm(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x, weight); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, x, weight, epsilon); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x, weight, epsilon); +} + +void RMSNorm::execute(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(RMSNorm, y, x, weight, epsilon); } -Tensor rms_norm(Tensor x, Tensor weight, float epsilon) { +Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon) { auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); rms_norm_(y, x, weight, epsilon); return y; } -void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon) { +void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { RMSNorm::execute(y, x, weight, epsilon); } diff --git a/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc b/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc index 17b0ad888..9e4622a28 100644 --- a/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc +++ b/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc @@ -1,50 +1,55 @@ -#include "../../utils.hpp" -#include 
"infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rms_norm.hpp" -#include -namespace infinicore::op::rms_norm_impl::infiniop { +#include "../infiniop_impl.hpp" -thread_local common::OpCache caches( - 100, // capacity - [](infiniopRMSNormDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRMSNormDescriptor(desc)); - desc = nullptr; - } - }); +namespace infinicore::op::rms_norm_impl::infiniop { -void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) { - size_t seed = hash_combine(y, x, weight, epsilon); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RMSNorm, 100); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, y, x, weight; +}; - auto desc_opt = cache.get(seed); - infiniopRMSNormDescriptor_t desc = nullptr; +void *plan(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { + size_t seed = hash_combine(y, x, weight, epsilon); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor( - context::getInfiniopHandle(device), &desc, - y->desc(), x->desc(), weight->desc(), epsilon)); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, RMSNorm, + seed, y->desc(), + x->desc(), + weight->desc(), + epsilon); + + INFINIOP_WORKSPACE_TENSOR(workspace, RMSNorm, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(x), + graph::GraphTensor(weight)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopRMSNorm( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->y->data(), + planned->x->data(), + planned->weight->data(), + context::getStream())); +} - INFINICORE_CHECK_ERROR(infiniopRMSNorm( - desc, workspace->data(), workspace_size, - y->data(), x->data(), weight->data(), context::getStream())); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - RMSNorm::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RMSNorm, &plan, &run, &cleanup); } // namespace infinicore::op::rms_norm_impl::infiniop diff --git a/src/infinicore/ops/rope/rope.cc b/src/infinicore/ops/rope/rope.cc index e0a187db3..d28951b7d 100644 --- a/src/infinicore/ops/rope/rope.cc +++ b/src/infinicore/ops/rope/rope.cc @@ -1,37 +1,44 @@ #include "infinicore/ops/rope.hpp" - #include "../../utils.hpp" -#include "infinicore/context/context.hpp" - -#include namespace infinicore::op { -common::OpDispatcher &RoPE::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RoPE); -void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { +RoPE::RoPE(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x_out, x, pos, sin_table, cos_table); 
- infinicore::context::setDevice(x_out->device()); - auto device_type = x_out->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No RoPE implementation found for device type: " + std::to_string(static_cast(device_type))); - } + INFINICORE_GRAPH_OP_DISPATCH(x_out->device().getType(), x_out, x, pos, sin_table, cos_table, algo); +} - func(x_out, x, pos, sin_table, cos_table, algo); +void RoPE::execute(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(RoPE, x_out, x, pos, sin_table, cos_table, algo); } -void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { +void rope_(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { RoPE::execute(x_out, x, pos, sin_table, cos_table, algo); } -Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { - Shape shape = x->shape(); - auto x_out = Tensor::empty(shape, x->dtype(), x->device()); +Tensor rope(const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { + auto x_out = Tensor::empty(x->shape(), x->dtype(), x->device()); rope_(x_out, x, pos, sin_table, cos_table, algo); return x_out; } diff --git a/src/infinicore/ops/rope/rope_infiniop.cc b/src/infinicore/ops/rope/rope_infiniop.cc index 412daa925..850c2d0a2 100644 --- a/src/infinicore/ops/rope/rope_infiniop.cc +++ b/src/infinicore/ops/rope/rope_infiniop.cc @@ -1,69 +1,81 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rope.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::rope_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopRoPEDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRoPEDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RoPE, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace; + graph::GraphTensor x_out; + graph::GraphTensor x; + graph::GraphTensor pos; + graph::GraphTensor sin; + graph::GraphTensor cos; +}; -void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) { - // Convert infinicore::nn::RoPE::Algo to infiniopRoPEAlgo_t - infiniopRoPEAlgo_t infiniop_algo; +static infiniopRoPEAlgo_t to_infiniop_algo(infinicore::nn::RoPE::Algo algo) { switch (algo) { case infinicore::nn::RoPE::Algo::GPT_J: - infiniop_algo = INFINIOP_ROPE_ALGO_GPT_J; - break; + return INFINIOP_ROPE_ALGO_GPT_J; case infinicore::nn::RoPE::Algo::GPT_NEOX: - infiniop_algo = INFINIOP_ROPE_ALGO_GPT_NEOX; - break; + return INFINIOP_ROPE_ALGO_GPT_NEOX; default: - throw std::runtime_error("Unsupported RoPE algorithm: " + std::to_string(static_cast(algo))); + throw std::runtime_error("Unsupported RoPE algorithm"); } +} - // Create hash key for descriptor caching - size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache); - hash_combine(key, std::hash()(static_cast(infiniop_algo))); +void *plan(Tensor x_out, + const Tensor &x, + const Tensor &pos, 
+ const Tensor &sin, + const Tensor &cos, + infinicore::nn::RoPE::Algo algo) { + auto infiniop_algo = to_infiniop_algo(algo); + size_t key = hash_combine(x_out, x, pos, sin, cos, static_cast(infiniop_algo)); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, RoPE, key, x_out->desc(), + x->desc(), + pos->desc(), + sin->desc(), + cos->desc(), + infiniop_algo); - auto desc_opt = cache.get(key); - infiniopRoPEDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, RoPE, descriptor); + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x_out), + graph::GraphTensor(x), + graph::GraphTensor(pos), + graph::GraphTensor(sin), + graph::GraphTensor(cos)}; +} - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor( - context::getInfiniopHandle(device), &desc, - x_out->desc(), x->desc(), - pos->desc(), sin_cache->desc(), cos_cache->desc(), - infiniop_algo)); - cache.put(key, desc); - } else { - desc = *desc_opt; - } +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetRoPEWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); + INFINICORE_CHECK_ERROR( + infiniopRoPE( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->x_out->data(), + p->x->data(), + p->pos->data(), + p->sin->data(), + p->cos->data(), + context::getStream())); +} - // InfiniOP reads from x and writes to x_out (handles copying internally) - INFINICORE_CHECK_ERROR(infiniopRoPE( - desc, workspace->data(), workspace_size, - x_out->data(), x->data(), pos->data(), - sin_cache->data(), cos_cache->data(), context::getStream())); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - RoPE::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RoPE, &plan, &run, &cleanup); } // namespace infinicore::op::rope_impl::infiniop diff --git a/src/infinicore/ops/swiglu/swiglu.cc b/src/infinicore/ops/swiglu/swiglu.cc index 5646180e7..8ee0682ad 100644 --- a/src/infinicore/ops/swiglu/swiglu.cc +++ b/src/infinicore/ops/swiglu/swiglu.cc @@ -1,37 +1,26 @@ #include "infinicore/ops/swiglu.hpp" - #include "../../utils.hpp" -#include - namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SwiGLU); -common::OpDispatcher &SwiGLU::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void SwiGLU::execute(Tensor c, Tensor a, Tensor b) { +SwiGLU::SwiGLU(Tensor c, const Tensor &a, const Tensor &b) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); - infinicore::context::setDevice(c->device()); - auto device_type = c->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No SwiGLU implementation found for device type: " + std::to_string(static_cast(device_type))); - } + INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b); +} - func(c, a, b); +void SwiGLU::execute(Tensor c, const Tensor &a, const Tensor &b) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(SwiGLU, c, a, b); } -Tensor swiglu(Tensor a, Tensor b) { - Shape shape = a->shape(); - auto c = Tensor::empty(shape, a->dtype(), a->device()); +Tensor swiglu(const Tensor &a, const Tensor &b) { + auto c = Tensor::empty(a->shape(), 
a->dtype(), a->device()); swiglu_(c, a, b); return c; } -void swiglu_(Tensor c, Tensor a, Tensor b) { +void swiglu_(Tensor c, const Tensor &a, const Tensor &b) { SwiGLU::execute(c, a, b); } + } // namespace infinicore::op diff --git a/src/infinicore/ops/swiglu/swiglu_infiniop.cc b/src/infinicore/ops/swiglu/swiglu_infiniop.cc index 4a963993b..fbb76b570 100644 --- a/src/infinicore/ops/swiglu/swiglu_infiniop.cc +++ b/src/infinicore/ops/swiglu/swiglu_infiniop.cc @@ -1,50 +1,55 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/swiglu.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::swiglu_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopSwiGLUDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroySwiGLUDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor c, Tensor a, Tensor b) { - size_t seed = hash_combine(c, b, a); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = cache.get(seed); - infiniopSwiGLUDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor( - context::getInfiniopHandle(device), &desc, - c->desc(), a->desc(), b->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopSwiGLU( - desc, workspace->data(), workspace_size, - c->data(), a->data(), b->data(), context::getStream())); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SwiGLU, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace; + graph::GraphTensor c; + graph::GraphTensor a; + graph::GraphTensor b; +}; + +void *plan(Tensor c, const Tensor &a, const Tensor &b) { + size_t key = hash_combine(c, a, b); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, SwiGLU, + key, c->desc(), a->desc(), b->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, SwiGLU, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(c), + graph::GraphTensor(a), + graph::GraphTensor(b)}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopSwiGLU( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->c->data(), + p->a->data(), + p->b->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - SwiGLU::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SwiGLU, &plan, &run, &cleanup); } // namespace infinicore::op::swiglu_impl::infiniop diff --git a/src/infinicore/utils.hpp b/src/infinicore/utils.hpp index cf8e69789..fd0578a5e 100644 --- a/src/infinicore/utils.hpp +++ b/src/infinicore/utils.hpp @@ -49,6 +49,7 @@ inline struct SpdlogInitializer { + ":" + std::to_string(__LINE__) + "."); \ } \ } \ + infinicore::context::setDevice((FIRST___)->device()); \ } while (0) #define INFINICORE_ASSERT(CONDITION__) \ diff --git a/test/infinicore/graph/attention.py b/test/infinicore/graph/attention.py new file mode 100644 index 000000000..cae70dc04 --- /dev/null +++ 
b/test/infinicore/graph/attention.py @@ -0,0 +1,356 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import BaseOperatorTest, GenericTestRunner, TensorSpec, TestCase +from framework.tensor import TensorInitializer + +import infinicore + +# Test cases format: (nlayers, batch_size, hidden_size, nhead, nkvhead, dim, seqlen, past_seqlen, max_seqlen) +_TEST_CASES_DATA = [ + (28, 1, 3584, 28, 28, 128, 1, 256, 512), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-4, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-4, "rtol": 1e-3}, + infinicore.bfloat16: {"atol": 1e-4, "rtol": 5e-2}, +} +_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16] + + +def parse_test_cases(): + cases = [] + for ( + nlayers, + batch_size, + hidden_size, + nhead, + nkvhead, + dim, + seqlen, + past_seqlen, + max_seqlen, + ) in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + hidden_states = TensorSpec.from_tensor( + (batch_size, seqlen, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + pos_ids = TensorSpec.from_tensor( + (batch_size, seqlen), + dtype=infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=0, + high=max_seqlen, + ) + k_cache = TensorSpec.from_tensor( + (nlayers, batch_size, nkvhead, max_seqlen, dim), + dtype=dtype, + scale=1e-1, + bias=-5e-2, + ) + v_cache = TensorSpec.from_tensor( + (nlayers, batch_size, nkvhead, max_seqlen, dim), + dtype=dtype, + scale=1e-1, + bias=-5e-2, + ) + q_proj_w = TensorSpec.from_tensor( + (nhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + k_proj_w = TensorSpec.from_tensor( + (nkvhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + v_proj_w = TensorSpec.from_tensor( + (nkvhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + o_proj_w = TensorSpec.from_tensor( + (hidden_size, nhead * dim), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + norm_w = TensorSpec.from_tensor( + (hidden_size,), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + sin_table = TensorSpec.from_tensor( + (max_seqlen, dim // 2), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + cos_table = TensorSpec.from_tensor( + (max_seqlen, dim // 2), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + + # Out-of-place + cases.append( + TestCase( + inputs=[ + hidden_states, + pos_ids, + nhead, + nkvhead, + dim, + past_seqlen, + nlayers, + k_cache, + v_cache, + q_proj_w, + k_proj_w, + v_proj_w, + o_proj_w, + norm_w, + sin_table, + cos_table, + ], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="Graph", + ) + ) + + return cases + + +def torch_rope( + q: torch.Tensor, + k: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + pos_ids: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + q, k: [B, H, S, D] + sin, cos: [max_S, D//2] + pos_ids: [B, S] + """ + + def rotate_half(x: torch.Tensor) -> torch.Tensor: + # x: [..., head_dim] + x_even = x[..., 0::2] + x_odd = x[..., 1::2] + return torch.stack((-x_odd, x_even), dim=-1).flatten(-2) + + B, H, S, D = q.shape + assert D % 2 == 0 + + # Gather sin/cos by position + # -> [B, S, D//2] + sin = sin[pos_ids] + cos = cos[pos_ids] + + # Expand to broadcast over heads + # -> [B, 1, S, D//2] + sin = sin.unsqueeze(1) + cos = cos.unsqueeze(1) + + # Interleave to full dim + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + + # Apply RoPE + q_rot = (q * cos) + (rotate_half(q) * sin) + k_rot = (k * cos) + (rotate_half(k) * 
sin) + + return q_rot, k_rot + + +class OpTest(BaseOperatorTest): + """Test Operator Graph""" + + def __init__(self): + super().__init__("Graph") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator( + self, + hidden_states, + pos_ids, + nhead, + nkvhead, + dim, + past_seqlen, + nlayers, + k_cache, + v_cache, + q_proj_w, + k_proj_w, + v_proj_w, + o_proj_w, + norm_w, + sin_table, + cos_table, + **kwargs, + ): + B, S, D = hidden_states.shape + + for layer in range(nlayers): + # ---- RMSNorm ---- + var = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(var + 1e-5) * norm_w + + # ---- QKV projection ---- + q = hidden_states @ q_proj_w.T + k = hidden_states @ k_proj_w.T + v = hidden_states @ v_proj_w.T + + q = q.view(B, S, nhead, dim).transpose(1, 2) # [B,H,S,Dh] + k = k.view(B, S, nkvhead, dim).transpose(1, 2) + v = v.view(B, S, nkvhead, dim).transpose(1, 2) + + # ---- RoPE ---- + q, k = torch_rope( + q, + k, + sin_table, + cos_table, + pos_ids, + ) + + # ---- KV cache update ---- + k_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = k + v_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = v + + K = k_cache[layer, :, :, 0 : past_seqlen + S, :] + V = v_cache[layer, :, :, 0 : past_seqlen + S, :] + + # ---- Scaled Dot Product Attention (fused) ---- + def scaled_dot_product_attention( + query, key, value, is_causal=False, enable_gqa=False + ) -> torch.Tensor: + S, L = query.size(-2), key.size(-2) + scale_factor = query.size(-1) ** -0.5 + attn_bias = torch.zeros(S, L, dtype=query.dtype, device=query.device) + if is_causal: + mask = torch.tril(attn_bias + 1, diagonal=-1).flip(dims=[-2, -1]) + attn_bias = torch.where(mask == 1, -torch.inf, 0.0) + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave( + query.size(-3) // value.size(-3), -3 + ) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + + attn_out = scaled_dot_product_attention( + q, + K, + V, + is_causal=True, + enable_gqa=True, + ) # [B,H,S,Dh] + + # ---- Output projection ---- + attn_out = attn_out.transpose(1, 2).contiguous() + attn_out = attn_out.view(B, S, nhead * dim) + + hidden_states = attn_out @ o_proj_w.T + + return hidden_states + + def infinicore_operator( + self, + hidden_states, + pos_ids, + nhead, + nkvhead, + dim, + past_seqlen, + nlayers, + k_cache, + v_cache, + q_proj_w, + k_proj_w, + v_proj_w, + o_proj_w, + norm_w, + sin_table, + cos_table, + **kwargs, + ): + """Record graph and run""" + input_hidden_states = hidden_states + B, S, D = input_hidden_states.shape + + infinicore.start_graph_recording() + for layer in range(nlayers): + hidden_states = infinicore.nn.functional.rms_norm( + hidden_states, norm_w.shape, norm_w, 1e-5 + ) + q = infinicore.nn.functional.linear(hidden_states, q_proj_w) + k = infinicore.nn.functional.linear(hidden_states, k_proj_w) + v = infinicore.nn.functional.linear(hidden_states, v_proj_w) + + q = q.view((B, S, nhead, dim)) + k = k.view((B, S, nkvhead, dim)) + v = v.view((B, S, nkvhead, dim)) + q = infinicore.nn.functional.rope( + q, + pos_ids, + sin_table, + cos_table, + infinicore.nn.functional.RopeAlgo.GPT_J, + ) + k = infinicore.nn.functional.rope( + k, + pos_ids, + sin_table, + cos_table, + infinicore.nn.functional.RopeAlgo.GPT_J, + ) + + # [B, KVH, total_len, D] + full_k = ( + k_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, 
past_seqlen + S) + ) + full_v = ( + v_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S) + ) + full_k.narrow(2, past_seqlen, S).copy_(k.permute((0, 2, 1, 3))) + full_v.narrow(2, past_seqlen, S).copy_(v.permute((0, 2, 1, 3))) + + G = nhead // nkvhead + L = past_seqlen + S + + full_q = ( + q.permute((0, 2, 1, 3)).contiguous().view((B * nkvhead, G * S, dim)) + ) + full_k = full_k.view((B * nkvhead, L, dim)) + full_v = full_v.view((B * nkvhead, L, dim)) + + attn_score = infinicore.matmul( + full_q, full_k.permute((0, 2, 1)), alpha=dim**-0.5 + ) + # [B * H, S, total_len] + attn_score = attn_score.view((B * nhead, S, L)) + infinicore.nn.functional.causal_softmax(attn_score, out=attn_score) + attn_out = infinicore.matmul(attn_score, full_v) + attn_out = ( + attn_out.view((B, nhead, S, dim)) + .permute((0, 2, 1, 3)) + .contiguous() + .view((B, S, nhead * dim)) + ) + hidden_states = infinicore.nn.functional.linear(attn_out, o_proj_w) + + op_graph = infinicore.stop_graph_recording() + + op_graph.run() + return hidden_states + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/graph/graph.py b/test/infinicore/graph/graph.py deleted file mode 100644 index 2f8927110..000000000 --- a/test/infinicore/graph/graph.py +++ /dev/null @@ -1,85 +0,0 @@ -import sys -import os - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -import torch -import infinicore -from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner - -# Test cases format: (in_shape, proj_w_shape) -_TEST_CASES_DATA = [ - ((32, 4096), (4096, 4096)), -] - -_TOLERANCE_MAP = { - infinicore.float16: {"atol": 0, "rtol": 1e-2}, - infinicore.float32: {"atol": 1e-4, "rtol": 1e-3}, - infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, -} -_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16] - - -def parse_test_cases(): - cases = [] - for in_shape, proj_w_shape in _TEST_CASES_DATA: - for dtype in _TENSOR_DTYPES: - tol = _TOLERANCE_MAP[dtype] - in_spec = TensorSpec.from_tensor(in_shape, dtype=dtype) - proj_w_spec = TensorSpec.from_tensor(proj_w_shape, dtype=dtype) - temp_spec = TensorSpec.from_tensor(in_shape, dtype=dtype) - - # Out-of-place - cases.append( - TestCase( - inputs=[in_spec, proj_w_spec, temp_spec], - kwargs={}, - output_spec=None, - comparison_target=None, - tolerance=tol, - description="Graph", - ) - ) - - return cases - - -class OpTest(BaseOperatorTest): - """Test Operator Graph""" - - def __init__(self): - super().__init__("Graph") - - def get_test_cases(self): - return parse_test_cases() - - def torch_operator(self, *args, **kwargs): - a = args[0] - b = args[1] - - return torch.matmul(a, b) - - def infinicore_operator(self, *args, **kwargs): - """Record graph and run""" - a = args[0] - b = args[1] - temp_a = args[2] - - infinicore.start_graph_recording() - c = infinicore.matmul(temp_a, b) - op_graph = infinicore.stop_graph_recording() - - temp_a.copy_(a) - op_graph.run() - - return c - - -def main(): - """Main entry point""" - runner = GenericTestRunner(OpTest) - runner.run_and_exit() - - -if __name__ == "__main__": - main() From 7c5aa160e9330845a3c17d3d3d0e9b41f79be043 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 06:54:52 +0000 Subject: [PATCH 12/25] issue/985 - adjust cxflags and cxxflags for lua scripts --- xmake.lua | 10 ++++++---- xmake/ascend.lua | 3 +++ xmake/bang.lua | 3 +++ xmake/cpu.lua | 8 +++++--- xmake/hygon.lua 
| 5 ++++- xmake/iluvatar.lua | 5 ++++- xmake/kunlun.lua | 3 +++ xmake/metax.lua | 12 +++++++++++- xmake/moore.lua | 3 +++ xmake/nvidia.lua | 3 +++ xmake/qy.lua | 3 +++ xmake/test.lua | 2 +- 12 files changed, 49 insertions(+), 11 deletions(-) diff --git a/xmake.lua b/xmake.lua index a51435325..a8e767723 100644 --- a/xmake.lua +++ b/xmake.lua @@ -19,7 +19,7 @@ end if is_plat("windows") then set_runtimes("MD") add_ldflags("/utf-8", {force = true}) - add_cxflags("/utf-8", {force = true}) + add_cxxflags("/utf-8", {force = true}) end -- CPU @@ -224,14 +224,15 @@ target("infini-utils") set_warnings("all", "error") if is_plat("windows") then - add_cxflags("/wd4068") + add_cxxflags("/wd4068") if has_config("omp") then - add_cxflags("/openmp") + add_cxxflags("/openmp") end else add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp", {force = true}) end end @@ -276,6 +277,7 @@ target("infinirt") set_languages("cxx17") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")) add_files("src/infinirt/*.cc") diff --git a/xmake/ascend.lua b/xmake/ascend.lua index 6a28979b4..e51626d1d 100644 --- a/xmake/ascend.lua +++ b/xmake/ascend.lua @@ -44,6 +44,7 @@ target("infiniop-ascend") on_install(function (target) end) add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") set_languages("cxx17") @@ -62,6 +63,7 @@ target("infinirt-ascend") -- Add files add_files("$(projectdir)/src/infinirt/ascend/*.cc") add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-ascend") @@ -76,5 +78,6 @@ target("infiniccl-ascend") add_links("libhccl.so") add_files("../src/infiniccl/ascend/*.cc") add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/bang.lua b/xmake/bang.lua index d2195acd5..ffa85ef6d 100644 --- a/xmake/bang.lua +++ b/xmake/bang.lua @@ -41,6 +41,7 @@ target("infiniop-cambricon") on_install(function (target) end) add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") set_languages("cxx17") @@ -59,6 +60,7 @@ target("infinirt-cambricon") -- Add include dirs add_files("../src/infinirt/bang/*.cc") add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-cambricon") @@ -89,6 +91,7 @@ target("infiniccl-cambricon") add_files("../src/infiniccl/cambricon/*.cc") add_cxflags("-fPIC") + add_cxxflags("-fPIC") add_ldflags("-fPIC") else print("[Warning] CNCL is currently only supported on Linux") diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 22dc8f8e7..e192fbbbd 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -6,14 +6,15 @@ target("infiniop-cpu") set_warnings("all", "error") if is_plat("windows") then - add_cxflags("/wd4068") + add_cxxflags("/wd4068") if has_config("omp") then - add_cxflags("/openmp") + add_cxxflags("/openmp") end else add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp") end end @@ -32,6 +33,7 @@ target("infinirt-cpu") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") diff --git a/xmake/hygon.lua 
b/xmake/hygon.lua index ed4b91f0e..05d3e8356 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -60,6 +60,7 @@ target("infiniop-hygon") add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") @@ -76,7 +77,7 @@ target("infiniop-hygon") add_files("../src/infiniop/ops/swiglu/nvidia/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) end target_end() @@ -105,6 +106,7 @@ target("infinirt-hygon") add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") @@ -138,6 +140,7 @@ target("infiniccl-hygon") add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 57a935f4f..cd9304127 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -49,6 +49,7 @@ target("infiniop-iluvatar") end add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu") @@ -57,7 +58,7 @@ target("infiniop-iluvatar") add_files("../src/infiniop/ops/dequantize_awq/iluvatar/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) end target_end() @@ -76,6 +77,7 @@ target("infinirt-iluvatar") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infinirt/cuda/*.cu") @@ -97,6 +99,7 @@ target("infiniccl-iluvatar") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/kunlun.lua b/xmake/kunlun.lua index 185082b3c..84ba14082 100644 --- a/xmake/kunlun.lua +++ b/xmake/kunlun.lua @@ -75,6 +75,7 @@ target("infiniop-kunlun") on_install(function (target) end) add_cxflags("-lstdc++ -fPIC -Wno-error=unused-function") + add_cxxflags("-lstdc++ -fPIC -Wno-error=unused-function") set_warnings("all", "error") set_languages("cxx17") @@ -102,6 +103,7 @@ target("infinirt-kunlun") -- Add include dirs add_files("$(projectdir)/src/infinirt/kunlun/*.cc") add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-kunlun") @@ -117,5 +119,6 @@ target("infiniccl-kunlun") add_links("bkcl") add_files("$(projectdir)/src/infiniccl/kunlun/*.cc") add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/metax.lua b/xmake/metax.lua index 5561b45db..4ee7e0895 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -48,11 +48,19 @@ target("infiniop-metax") set_languages("cxx17") set_warnings("all", "error") add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) + add_cxxflags("-lstdc++", "-fPIC", 
"-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) add_files("../src/infiniop/devices/metax/*.cc", "../src/infiniop/ops/*/metax/*.cc") add_files("../src/infiniop/ops/*/metax/*.maca", {rule = "maca"}) if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-include stdlib.h", "-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", { + cxflags = { + "-include stdlib.h", + "-Wno-return-type", + "-Wno-implicit-function-declaration", + "-Wno-builtin-declaration-mismatch" + } + }) end target_end() @@ -63,6 +71,7 @@ target("infinirt-metax") add_deps("infini-utils") set_warnings("all", "error") add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") add_files("../src/infinirt/metax/*.cc") target_end() @@ -73,6 +82,7 @@ target("infiniccl-metax") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end if has_config("ccl") then if has_config("use-mc") then diff --git a/xmake/moore.lua b/xmake/moore.lua index 25eddf522..fdcad9564 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -42,6 +42,7 @@ target("infiniop-moore") set_languages("cxx17") set_warnings("all", "error") add_cxflags("-lstdc++", "-fPIC", "-Wno-comment") + add_cxxflags("-lstdc++", "-fPIC", "-Wno-comment") add_files("../src/infiniop/devices/moore/*.cc") add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"}) @@ -56,6 +57,7 @@ target("infinirt-moore") add_deps("infini-utils") set_warnings("all", "error") add_cxflags("-lstdc++", "-fPIC") + add_cxxflags("-lstdc++", "-fPIC") add_files("../src/infinirt/moore/*.cc") target_end() @@ -66,6 +68,7 @@ target("infiniccl-moore") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end if has_config("ccl") then add_links("libmccl.so") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index a1133b15b..5e9eef5f3 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -48,6 +48,7 @@ target("infiniop-nvidia") add_cuflags("-Xcompiler=-fPIC") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_cflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") @@ -93,6 +94,7 @@ target("infinirt-nvidia") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") @@ -112,6 +114,7 @@ target("infiniccl-nvidia") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/qy.lua b/xmake/qy.lua index 4a512e203..bd591249a 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -88,6 +88,7 @@ target("infiniop-qy") add_cuflags("-Xcompiler=-fPIC") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") if CUDNN_ROOT ~= nil then @@ -117,6 +118,7 @@ target("infinirt-qy") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") @@ -133,6 +135,7 @@ target("infiniccl-qy") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/test.lua b/xmake/test.lua index 002083e1d..56dca6e5f 100644 --- a/xmake/test.lua +++ b/xmake/test.lua @@ -24,7 +24,7 @@ target("infiniop-test") add_links("infiniop", 
"infinirt") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp") end From 55cd22e312d497a203f7b985f5375810665b7e96 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 25 Aug 2025 20:04:27 +0800 Subject: [PATCH 13/25] issue/402 - convenient ninetoothed util MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 对 `NineToothedTensor` 进行 C++ 层封装 加入使用数组作为 `shape` 和 `strides` 创建 `ninetoothed::Tensor` 的方式 使用 `ninetoothed::Tensor` 接入九齿的 ReLU 算子 Add an include guard to `ninetoothed/utils.h` --- src/infiniop/ninetoothed/utils.h | 75 +++++++++++++++++++++ src/infiniop/ops/relu/metax/relu_metax.maca | 21 ++---- src/infiniop/ops/relu/nvidia/relu_nvidia.cu | 21 ++---- 3 files changed, 85 insertions(+), 32 deletions(-) create mode 100644 src/infiniop/ninetoothed/utils.h diff --git a/src/infiniop/ninetoothed/utils.h b/src/infiniop/ninetoothed/utils.h new file mode 100644 index 000000000..1b7d1fe3a --- /dev/null +++ b/src/infiniop/ninetoothed/utils.h @@ -0,0 +1,75 @@ +#ifndef __NINETOOTHED_UTILS__ +#define __NINETOOTHED_UTILS__ + +#include +#include +#include +#include + +namespace ninetoothed { + +template +class Tensor { +public: + using Data = decltype(NineToothedTensor::data); + + using Size = std::remove_pointer_t; + + using Stride = std::remove_pointer_t; + + template + Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {} + + Tensor(const void *data, std::initializer_list shape, std::initializer_list strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {} + + Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {} + + Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {} + + operator NineToothedTensor() { return {const_cast(data_), shape_.data(), strides_.data()}; } + + template + Tensor expand(const Shape &sizes) const { + auto new_ndim{sizes.size()}; + + decltype(shape_) shape(new_ndim, 1); + decltype(strides_) strides(new_ndim, 0); + + auto num_new_dims{new_ndim - ndim_}; + + for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) { + shape[dim + num_new_dims] = shape_[dim]; + strides[dim + num_new_dims] = strides_[dim]; + } + + for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) { + if (sizes[dim] == std::numeric_limits>::max() || shape[dim] != 1) { + continue; + } + + shape[dim] = sizes[dim]; + strides[dim] = 0; + } + + return {data_, shape, strides}; + } + + Tensor expand_as(const Tensor &other) const { + return expand(other.shape_); + } + +private: + const void *data_{nullptr}; + + std::vector shape_; + + std::vector strides_; + + Size ndim_{0}; + + T value_{0}; +}; + +} // namespace ninetoothed + +#endif diff --git a/src/infiniop/ops/relu/metax/relu_metax.maca b/src/infiniop/ops/relu/metax/relu_metax.maca index 900fce9e0..2c5104bdd 100644 --- a/src/infiniop/ops/relu/metax/relu_metax.maca +++ b/src/infiniop/ops/relu/metax/relu_metax.maca @@ -2,6 +2,7 @@ #include "../../../../../build/ninetoothed/relu.h" #include "../../../devices/metax/metax_common.h" +#include "../../../ninetoothed/utils.h" #include "relu_metax.h" namespace op::relu::metax { @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( } const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector 
x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { diff --git a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu index 22b85a401..a3c79fb52 100644 --- a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu +++ b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu @@ -1,5 +1,6 @@ #ifdef ENABLE_NINETOOTHED #include "../../../../../build/ninetoothed/relu.h" +#include "../../../ninetoothed/utils.h" #endif #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../elementwise/nvidia/elementwise_nvidia.cuh" @@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate( } #ifdef ENABLE_NINETOOTHED const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { From 32340fc34588bff691a9a816e8dfa9ec02cb3f37 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 Jan 2026 07:22:50 +0000 Subject: [PATCH 14/25] issue/925 - Speed up `scripts/build_ntops.py` and `src/infiniop/ninetoothed/build.py` with `concurrent.futures` --- scripts/build_ntops.py | 28 +++++++---- src/infiniop/ninetoothed/build.py | 77 +++++++++++++++++++------------ 2 files changed, 68 insertions(+), 37 deletions(-) diff --git a/scripts/build_ntops.py b/scripts/build_ntops.py index 1499b6bf8..e1397e56d 100644 --- a/scripts/build_ntops.py +++ b/scripts/build_ntops.py @@ -1,3 +1,4 @@ +import concurrent.futures import importlib import pathlib @@ -11,16 +12,27 @@ def _find_and_build_ops(): ops_path = SRC_DIR_PATH / "infiniop" / "ops" - for op_dir in ops_path.iterdir(): - ninetoothed_path = op_dir / "ninetoothed" + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] - if ninetoothed_path.is_dir(): - module_path = ninetoothed_path / "build" - relative_path = module_path.relative_to(SRC_DIR_PATH) - 
import_name = ".".join(relative_path.parts) - module = importlib.import_module(import_name) + for op_dir in ops_path.iterdir(): + ninetoothed_path = op_dir / "ninetoothed" - module.build() + if not ninetoothed_path.is_dir(): + continue + + futures.append(executor.submit(_build, ninetoothed_path)) + + concurrent.futures.as_completed(futures) + + +def _build(ninetoothed_path): + module_path = ninetoothed_path / "build" + relative_path = module_path.relative_to(SRC_DIR_PATH) + import_name = ".".join(relative_path.parts) + module = importlib.import_module(import_name) + + module.build() if __name__ == "__main__": diff --git a/src/infiniop/ninetoothed/build.py b/src/infiniop/ninetoothed/build.py index aea421b7f..153e6b9f5 100644 --- a/src/infiniop/ninetoothed/build.py +++ b/src/infiniop/ninetoothed/build.py @@ -1,3 +1,4 @@ +import concurrent.futures import functools import inspect import itertools @@ -16,40 +17,28 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): headers = [] all_param_names = [] + combinations = [] launches = [] - for combination in _generate_param_value_combinations(constexpr_param_grid): - arrangement, application, tensors = premake(**combination) + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] - for param_name, param_value in combination.items(): - if isinstance(param_value, str): - combination[param_name] = ( - f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}" - ) + for combination in tuple( + _generate_param_value_combinations(constexpr_param_grid) + ): + future = executor.submit( + _make, premake, combination, caller, op_name, output_dir + ) - combination = {f"{name}_": value for name, value in combination.items()} + futures.append(future) - kernel_name = f"{op_name}_{_generate_suffix(combination.values())}" + for future in concurrent.futures.as_completed(futures): + header, param_names, combination, launch = future.result() - ninetoothed.make( - arrangement, - application, - tensors, - caller=caller, - kernel_name=kernel_name, - output_dir=output_dir, - ) - - header = output_dir / f"{kernel_name}.h" - param_names = ("stream",) + tuple( - inspect.signature(application).parameters.keys() - ) - launch = f""" if ({_generate_condition(combination)}) - return launch_{kernel_name}({", ".join(param_names)});""" - - headers.append(header) - all_param_names.append(param_names) - launches.append(launch) + headers.append(header) + all_param_names.append(param_names) + combinations.append(combination) + launches.append(launch) includes = "\n".join(f'#include "{header}"' for header in headers) @@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): "NineToothedStream", ] + ["NineToothedTensor" for _ in range(len(param_names) - 1)] - for param_name in combination: + for param_name in functools.reduce(lambda x, y: x | y, combinations, {}): param_names.append(param_name) param_types.append("int") @@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content) +def _make(premake, combination, caller, op_name, output_dir): + arrangement, application, tensors = premake(**combination) + + for param_name, param_value in combination.items(): + if isinstance(param_value, str): + combination[param_name] = ( + f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}" + ) + + combination = {f"{name}_": value for name, value in combination.items()} + + kernel_name = 
f"{op_name}_{_generate_suffix(combination.values())}" + + ninetoothed.make( + arrangement, + application, + tensors, + caller=caller, + kernel_name=kernel_name, + output_dir=output_dir, + ) + + header = output_dir / f"{kernel_name}.h" + param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys()) + launch = f""" if ({_generate_condition(combination)}) + return launch_{kernel_name}({", ".join(param_names)});""" + + return header, param_names, combination, launch + + def _generate_condition(combination): return " && ".join(f"{param} == {value}" for param, value in combination.items()) From ca58118f04a26f8157ed757a60713c1a2f75c945 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 03:25:05 +0000 Subject: [PATCH 15/25] issue/940 - check build result and implicitly require build.py for build ntops --- scripts/build_ntops.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/build_ntops.py b/scripts/build_ntops.py index e1397e56d..601249615 100644 --- a/scripts/build_ntops.py +++ b/scripts/build_ntops.py @@ -21,9 +21,14 @@ def _find_and_build_ops(): if not ninetoothed_path.is_dir(): continue + build_file = ninetoothed_path / "build.py" + if not build_file.exists(): + continue + futures.append(executor.submit(_build, ninetoothed_path)) - concurrent.futures.as_completed(futures) + for future in concurrent.futures.as_completed(futures): + future.result() def _build(ninetoothed_path): From 47843aa68e4024ee58854c72a3218caf6306e3d8 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 15 Jan 2026 12:15:06 +0000 Subject: [PATCH 16/25] issue/935 - add metax include dir for ninetoothed --- xmake/metax.lua | 1 + 1 file changed, 1 insertion(+) diff --git a/xmake/metax.lua b/xmake/metax.lua index 4ee7e0895..65e5d549b 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -53,6 +53,7 @@ target("infiniop-metax") add_files("../src/infiniop/ops/*/metax/*.maca", {rule = "maca"}) if has_config("ninetoothed") then + add_includedirs(MACA_ROOT .. 
"/include/mcr") add_files("../build/ninetoothed/*.c", { cxflags = { "-include stdlib.h", From 6ac8f9065b70249e938dfff1bebebe12f316c583 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 11:21:39 +0000 Subject: [PATCH 17/25] issue/919 - ninetoothed flash attention --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/flash_attention.hpp | 12 + include/infiniop.h | 1 + include/infiniop/ops/flash_attention.h | 36 +++ python/infinicore/nn/functional/__init__.py | 10 +- .../nn/functional/flash_attention.py | 34 +++ .../ops/flash_attention/flash_attention.cc | 31 ++ .../flash_attention_infiniop.cc | 55 ++++ src/infinicore/pybind11/ops.hpp | 4 +- .../pybind11/ops/flash_attention.hpp | 22 ++ .../ops/flash_attention/ninetoothed/build.py | 46 +++ .../flash_attention/ninetoothed/descriptor.h | 147 +++++++++ .../ninetoothed/flash_attention.py | 281 ++++++++++++++++++ src/infiniop/ops/flash_attention/operator.cc | 121 ++++++++ test/infinicore/ops/flash_attention.py | 115 +++++++ 15 files changed, 911 insertions(+), 5 deletions(-) create mode 100644 include/infinicore/ops/flash_attention.hpp create mode 100644 include/infiniop/ops/flash_attention.h create mode 100644 python/infinicore/nn/functional/flash_attention.py create mode 100644 src/infinicore/ops/flash_attention/flash_attention.cc create mode 100644 src/infinicore/ops/flash_attention/flash_attention_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/flash_attention.hpp create mode 100644 src/infiniop/ops/flash_attention/ninetoothed/build.py create mode 100644 src/infiniop/ops/flash_attention/ninetoothed/descriptor.h create mode 100644 src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py create mode 100644 src/infiniop/ops/flash_attention/operator.cc create mode 100644 test/infinicore/ops/flash_attention.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 88dd5b342..772bc030e 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -5,6 +5,7 @@ #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" +#include "ops/flash_attention.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" diff --git a/include/infinicore/ops/flash_attention.hpp b/include/infinicore/ops/flash_attention.hpp new file mode 100644 index 000000000..24e33cfb6 --- /dev/null +++ b/include/infinicore/ops/flash_attention.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool); + +Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal); +void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal); +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 378a79a43..092868923 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -10,6 +10,7 @@ #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" #include "infiniop/ops/embedding.h" +#include "infiniop/ops/flash_attention.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/layer_norm.h" diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h new file mode 100644 index 000000000..5ea71335b --- /dev/null +++ 
b/include/infiniop/ops/flash_attention.h @@ -0,0 +1,36 @@ +#ifndef __INFINIOP_FLASH_ATTENTION_API_H__ +#define __INFINIOP_FLASH_ATTENTION_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + float scale, + char is_causal); + +__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *total_kv_len, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor( + infiniopFlashAttentionDescriptor_t desc); +#endif diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..e1ae309f5 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,5 +1,6 @@ from .causal_softmax import causal_softmax from .embedding import embedding +from .flash_attention import flash_attention from .linear import linear from .random_sample import random_sample from .rms_norm import rms_norm @@ -9,12 +10,13 @@ __all__ = [ "causal_softmax", + "embedding", + "flash_attention", + "linear", "random_sample", "rms_norm", + "RopeAlgo", + "rope", "silu", "swiglu", - "linear", - "embedding", - "rope", - "RopeAlgo", ] diff --git a/python/infinicore/nn/functional/flash_attention.py b/python/infinicore/nn/functional/flash_attention.py new file mode 100644 index 000000000..8f42e865f --- /dev/null +++ b/python/infinicore/nn/functional/flash_attention.py @@ -0,0 +1,34 @@ +import math + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def flash_attention( + query, + key, + value, + total_kv_len, + attn_mask=None, + dropout_p=0, + is_causal=False, + scale=None, + enable_gqa=False, +): + assert attn_mask is None and dropout_p == 0 and not enable_gqa + + emb_dim = query.shape[-1] + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + return Tensor( + _infinicore.flash_attention( + query._underlying, + key._underlying, + value._underlying, + total_kv_len._underlying, + scale, + is_causal, + ) + ) diff --git a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc new file mode 100644 index 000000000..21cd56010 --- /dev/null +++ b/src/infinicore/ops/flash_attention/flash_attention.cc @@ -0,0 +1,31 @@ +#include "infinicore/ops/flash_attention.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention); + +FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k, v, total_kv_len, scale, is_causal); +} + +void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { 
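+    // Same behavior as the expanded form written out for Rearrange::execute
+    // earlier in this series: while graph recording is active the op is
+    // appended to the graph being recorded, otherwise it runs immediately on
+    // the current stream.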
+ INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal); +} + +Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + Shape shape = q->shape(); + int idx = shape.size() - 1; + shape[idx] = v->shape()[idx]; + auto out = Tensor::empty(shape, q->dtype(), q->device()); + flash_attention_(out, q, k, v, total_kv_len, scale, is_causal); + return out; +} + +void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc new file mode 100644 index 000000000..f5207f0ee --- /dev/null +++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc @@ -0,0 +1,55 @@ +#include "../../utils.hpp" +#include "../infiniop_impl.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/flash_attention.hpp" +#include + +namespace infinicore::op::flash_attention_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, q, k, v, total_kv_len; + float scale; + bool is_causal; +}; + +void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, FlashAttention, + seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal); + + INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor); + + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(total_kv_len), scale, is_causal}; + + return planned; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopFlashAttention( + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), + planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FlashAttention, &plan, &run, &cleanup); + +} // namespace infinicore::op::flash_attention_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 3d6ebe79a..c53218990 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -7,6 +7,7 @@ #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" +#include "ops/flash_attention.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -29,13 +30,14 @@ inline void bind(py::module &m) { bind_add_rms_norm(m); bind_attention(m); bind_causal_softmax(m); - bind_random_sample(m); + bind_flash_attention(m); bind_linear(m); bind_matmul(m); bind_mul(m); bind_paged_attention(m); bind_paged_attention_prefill(m); 
bind_paged_caching(m); + bind_random_sample(m); bind_rearrange(m); bind_rms_norm(m); bind_silu(m); diff --git a/src/infinicore/pybind11/ops/flash_attention.hpp b/src/infinicore/pybind11/ops/flash_attention.hpp new file mode 100644 index 000000000..6e3766796 --- /dev/null +++ b/src/infinicore/pybind11/ops/flash_attention.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "infinicore/ops/flash_attention.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_flash_attention(py::module &m) { + m.def("flash_attention", + &op::flash_attention, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("total_kv_len"), + py::arg("scale"), + py::arg("is_causal")); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py new file mode 100644 index 000000000..23f265e2e --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py @@ -0,0 +1,46 @@ +import ninetoothed +from . import flash_attention +from .flash_attention import CausalVariant + +import infiniop.ninetoothed.build + +import torch + + +def build(): + + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + for i in range(device_count): + device_name = torch.cuda.get_device_name(i).lower() + + if "metax" in device_name: + return + + with_kv_cache_values = (0,) + emb_dim_values = (16, 32, 64, 128, 256) + is_causal_values = (0, 1) + with_attn_mask_values = (0,) + causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT) + dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32) + block_size_m_values = (256,) + block_size_n_values = (64,) + + constexpr_param_grid = { + "with_kv_cache": with_kv_cache_values, + "emb_dim": emb_dim_values, + "is_causal": is_causal_values, + "with_attn_mask": with_attn_mask_values, + "causal_variant": causal_variant_values, + "dtype": dtype_values, + "block_size_m": block_size_m_values, + "block_size_n": block_size_n_values, + } + + infiniop.ninetoothed.build.build( + flash_attention.premake, + constexpr_param_grid, + caller="cuda", + op_name="flash_attention", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h new file mode 100644 index 000000000..0a6e9c1f8 --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -0,0 +1,147 @@ +#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__ +#define __FLASH_ATTENTION_DESCRIPTOR_H__ + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/flash_attention.h" +#include "../../../ninetoothed/utils.h" + +namespace op::flash_attention::ninetoothed { + +class Descriptor final : public InfiniopDescriptor { +public: + Descriptor(infiniopHandle_t handle, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + double scale, + char is_causal) : InfiniopDescriptor{handle->device, handle->device_id}, + _query_shape{q_desc->shape()}, + _query_strides{q_desc->strides()}, + _key_shape{k_desc->shape()}, + _key_strides{k_desc->strides()}, + _value_shape{v_desc->shape()}, + _value_strides{v_desc->strides()}, + _total_kv_shape{total_kv_len->shape()}, + _total_kv_strides{total_kv_len->strides()}, + 
_output_strides{out_desc->strides()}, + _dtype{q_desc->dtype()}, + _scale{scale}, + _is_causal{is_causal} { + } + + ~Descriptor() = default; + + size_t get_workspace_size() const { + return 0; + } + + infiniStatus_t calculate(void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *total_kv_len, + void *stream) const { + uint64_t empty_shape[4]; + int64_t empty_strides[4]; + + auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}}; + auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}}; + auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}}; + auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}}; + + NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides}; + NineToothedTensor is_causal; + NineToothedTensor scale{const_cast(&_scale), nullptr, nullptr}; + auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}}; + NineToothedTensor with_attn_mask; + NineToothedTensor causal_variant; + + const auto with_kv_cache_{0}; + const auto emb_dim_{_query_shape[3]}; + const auto is_causal_{_is_causal}; + const auto with_attn_mask_{0}; + const auto causal_variant_{2}; + const auto dtype_{_dtype}; + + constexpr auto block_size_m_{256}; + constexpr auto block_size_n_{64}; + + if (launch_flash_attention(stream, + query, + key, + value, + total_kv_length, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + with_kv_cache_, + emb_dim_, + is_causal_, + with_attn_mask_, + causal_variant_, + dtype_, + block_size_m_, + block_size_n_)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + + static infiniStatus_t create(infiniopHandle_t handle, + Descriptor **desc, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + double scale, + char is_causal) { + *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal}; + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector _query_shape; + + std::vector _query_strides; + + std::vector _key_shape; + + std::vector _key_strides; + + std::vector _value_shape; + + std::vector _value_strides; + + std::vector _total_kv_shape; + + std::vector _total_kv_strides; + + std::vector _output_strides; + + infiniDtype_t _dtype; + + double _scale; + + char _is_causal; +}; + +} // namespace op::flash_attention::ninetoothed + +#endif // __FLASH_ATTENTION_DESCRIPTOR_H__ diff --git a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py new file mode 100644 index 000000000..22d63ae4a --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py @@ -0,0 +1,281 @@ +import enum +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +BLOCK_SIZE_M = ninetoothed.block_size() +BLOCK_SIZE_N = ninetoothed.block_size() + + +class CausalVariant(enum.IntEnum): + """Please refer to ``_.""" + + UPPER_LEFT = enum.auto() + + LOWER_RIGHT = enum.auto() + + +def arrangement( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + 
causal_variant, + with_kv_cache, + block_size_m=None, + block_size_n=None, +): + def arrange_query_or_output(input): + arranged = input.tile((1, 1, block_size_m, -1)).tile( + (1, query.shape[-3] // key.shape[-3], 1, 1) + ) + arranged.dtype = arranged.dtype.squeeze((0, 2, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + def arrange_key_or_value(input): + arranged = ( + input.tile((1, 1, block_size_n, -1)) + .tile((1, 1, -1, -1)) + .expand((-1, -1, query_arranged.shape[-2], -1)) + ) + arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + def arrange_total_kv_len(input, shape): + arranged = input.tile((1,)) + arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape) + return arranged + + def arrange_present_key_or_present_value(input): + arranged = input.tile((1, 1, block_size_m, block_size_n)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + + def arrange_attn_mask(input): + arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1, 2)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + if block_size_m is None: + block_size_m = BLOCK_SIZE_M + + if block_size_n is None: + block_size_n = BLOCK_SIZE_N + + query_arranged = arrange_query_or_output(query) + key_arranged = arrange_key_or_value(key) + value_arranged = arrange_key_or_value(value) + total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape) + present_key_arranged = arrange_present_key_or_present_value(present_key) + present_value_arranged = arrange_present_key_or_present_value(present_value) + present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot) + present_value_slot_arranged = arrange_present_key_or_present_value( + present_value_slot + ) + attn_mask_arranged = arrange_attn_mask(attn_mask) + is_causal_arranged = is_causal + scale_arranged = scale + output_arranged = arrange_query_or_output(output) + with_attn_mask_arranged = with_attn_mask + causal_variant_arranged = causal_variant + + if with_kv_cache: + return ( + query_arranged, + key_arranged, + value_arranged, + total_kv_len_arranged, + present_key_arranged, + present_value_arranged, + present_key_slot_arranged, + present_value_slot_arranged, + attn_mask_arranged, + is_causal_arranged, + scale_arranged, + output_arranged, + with_attn_mask_arranged, + causal_variant_arranged, + ) + + return ( + query_arranged, + key_arranged, + value_arranged, + total_kv_len_arranged, + attn_mask_arranged, + is_causal_arranged, + scale_arranged, + output_arranged, + with_attn_mask_arranged, + causal_variant_arranged, + ) + + +def application_with_kv_cache( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, +): + present_key_slot = present_key # noqa: F841 + present_value_slot = present_value # noqa: F841 + + application_without_kv_cache( + query, + key, + value, + total_kv_len, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + ) + + +def application_without_kv_cache( + query, + key, + value, + total_kv_len, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, +): + actual_kv_len = total_kv_len[0] + + for i in range(query.shape[0]): + query_i = (1.4426950408889634 * scale * 
query[i]).to(query[i].dtype) + + acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32) + lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32) + max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32) + + for j in range(-(-actual_kv_len // key.dtype.shape[0])): + + qk = ntl.dot(query_i, ntl.trans(key[j])) + + key_pos = key[j].offsets(-2) + qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf")) + + if with_attn_mask: + qk += attn_mask[j] + + if is_causal: + query_pos = query[i].offsets(-2) + + if causal_variant == 2: # CausalVariant.LOWER_RIGHT: + mask = ( + query_pos[:, None] + actual_kv_len - query.source.shape[-2] + >= key_pos[None, :] + ) + else: + mask = query_pos[:, None] >= key_pos[None, :] + + qk = ntl.where(mask, qk, float("-inf")) + + next_max = ntl.maximum(max, ntl.max(qk, 1)) + stable_qk = ntl.exp2(qk - next_max[:, None]) + + alpha = ntl.exp2(max - next_max) + acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j]) + max = next_max + lse = lse * alpha + ntl.sum(stable_qk, 1) + + acc /= lse[:, None] + output[i] = acc # noqa: F841 + + +def premake( + with_kv_cache, + emb_dim=None, + is_causal=None, + with_attn_mask=None, + causal_variant=None, + dtype=None, + block_size_m=None, + block_size_n=None, +): + arrangement_ = functools.partial( + arrangement, + with_kv_cache=with_kv_cache, + block_size_m=block_size_m, + block_size_n=block_size_n, + ) + + query, key, value, attn_mask, output = ( + Tensor( + 4, + dtype=dtype, + shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}), + ) + for _ in range(5) + ) + total_kv_len = Tensor(1, dtype=ninetoothed.int32) + present_key, present_value, present_key_slot, present_value_slot = ( + Tensor(4, dtype=dtype) for _ in range(4) + ) + scale = Tensor(0, dtype=ninetoothed.float64) + is_causal = Tensor(0, constexpr=True, value=is_causal) + with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask) + causal_variant = Tensor(0, constexpr=True, value=causal_variant) + + if emb_dim is not None: + for tensor in (query, key, value, attn_mask, output): + tensor.shape = tensor.shape[:-1] + (emb_dim,) + + if with_kv_cache: + application = application_with_kv_cache + else: + application = application_without_kv_cache + + tensors = ( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + ) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc new file mode 100644 index 000000000..53e484dd5 --- /dev/null +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/flash_attention.h" + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) +#include "ninetoothed/descriptor.h" +#endif +#endif + +__C infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + float scale, + char is_causal) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::flash_attention::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + q_desc, \ + k_desc, 
\ + v_desc, \ + total_kv_len, \ + scale, \ + is_causal); + + switch (handle->device) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->get_workspace_size(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *total_kv_len, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream); + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor( + infiniopFlashAttentionDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/test/infinicore/ops/flash_attention.py b/test/infinicore/ops/flash_attention.py new file mode 100644 index 000000000..2d4b09599 --- /dev/null +++ b/test/infinicore/ops/flash_attention.py @@ -0,0 +1,115 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + TensorSpec, + TensorInitializer, + TestCase, + GenericTestRunner, +) + +# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal) +# q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim) + +_TEST_CASES_DATA = [ + ((1, 1, 2, 16), (1, 1, 8, 16), (1, 1, 8, 16), None, 0.0, False), + ((1, 2, 128, 16), (1, 2, 256, 16), (1, 2, 256, 16), None, 0.0, False), + ((1, 1, 4, 32), (1, 1, 32, 32), (1, 1, 32, 32), None, 0.0, True), + ((1, 8, 256, 16), (1, 8, 512, 16), (1, 8, 512, 16), None, 0.0, True), + ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, False), + ((8, 28, 256, 128), (8, 28, 512, 128), (8, 28, 512, 128), None, 0.0, True), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, +} +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + import random + + cases = [] + for q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + q_spec = 
TensorSpec.from_tensor(q_shape, None, dtype) + k_spec = TensorSpec.from_tensor(k_shape, None, dtype) + v_spec = TensorSpec.from_tensor(v_shape, None, dtype) + + len_shape = (q_shape[0],) + total_len = random.randint(1, k_shape[2]) + total_kv_len_spec = TensorSpec.from_tensor( + len_shape, + None, + infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=total_len, + high=total_len + 1, + ) + + kwargs = { + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + } + # remove None keys + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + cases.append( + TestCase( + inputs=[q_spec, k_spec, v_spec, total_kv_len_spec, total_len], + kwargs=kwargs, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="Flash Attention", + ) + ) + + return cases + + +def torch_flash_attn(q, k, v, total_kv_len, cheat, **kwargs): + k_slice = k[:, :, :cheat, :] + v_slice = v[:, :, :cheat, :] + return torch.nn.functional.scaled_dot_product_attention( + q, k_slice, v_slice, **kwargs + ) + + +def infini_flash_attn(q, k, v, total_kv_len, cheat, **kwargs): + return infinicore.nn.functional.flash_attention(q, k, v, total_kv_len, **kwargs) + + +class OpTest(BaseOperatorTest): + """ScaledDotProductAttention operator test with simplified implementation""" + + def __init__(self): + super().__init__("ScaledDotProductAttention") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_flash_attn(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infini_flash_attn(*args, **kwargs) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() From 5614e1be5228862b6db0ef0c40d534f14edc5577 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 06:47:43 +0000 Subject: [PATCH 18/25] issue/931 - ninetoothed swiglu for nv, il, mtx --- src/infiniop/ops/swiglu/ninetoothed/build.py | 29 +++++++ src/infiniop/ops/swiglu/ninetoothed/swiglu.h | 82 +++++++++++++++++++ src/infiniop/ops/swiglu/ninetoothed/swiglu.py | 22 +++++ src/infiniop/ops/swiglu/operator.cc | 56 +++++++++++++ 4 files changed, 189 insertions(+) create mode 100644 src/infiniop/ops/swiglu/ninetoothed/build.py create mode 100644 src/infiniop/ops/swiglu/ninetoothed/swiglu.h create mode 100644 src/infiniop/ops/swiglu/ninetoothed/swiglu.py diff --git a/src/infiniop/ops/swiglu/ninetoothed/build.py b/src/infiniop/ops/swiglu/ninetoothed/build.py new file mode 100644 index 000000000..fa4af6db2 --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/build.py @@ -0,0 +1,29 @@ +import ninetoothed +from . 
import swiglu + +import infiniop.ninetoothed.build + + +def build(): + MAX_NDIM = 5 + + ndim_values = range(1, MAX_NDIM + 1) + dtype_values = ( + ninetoothed.float16, + ninetoothed.bfloat16, + ninetoothed.float32, + ) + + constexpr_param_grid = { + "ndim": ndim_values, + "dtype": dtype_values, + "block_size": (1024,), + } + + infiniop.ninetoothed.build.build( + swiglu.premake, + constexpr_param_grid, + caller="cuda", + op_name="swiglu", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.h b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h new file mode 100644 index 000000000..4aa2fa70e --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h @@ -0,0 +1,82 @@ +#ifndef SWIGLU_H +#define SWIGLU_H + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/swiglu.h" +#include "../../../ninetoothed/utils.h" + +namespace op::swiglu::ninetoothed { +class Descriptor final : public InfiniopDescriptor { + +public: + Descriptor( + infiniopHandle_t handle, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id}, + out_shape_{out_desc->shape()}, + out_strides_{out_desc->strides()}, + up_shape_{input_desc_vec[0]->shape()}, + up_strides_{input_desc_vec[0]->strides()}, + gate_shape_{input_desc_vec[1]->shape()}, + gate_strides_{input_desc_vec[1]->strides()}, + dtype_{out_desc->dtype()} {} + + ~Descriptor() = default; + + size_t workspaceSize() const { + return 0; + } + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + *desc_ptr = new Descriptor(handle, out_desc, input_desc_vec); + return INFINI_STATUS_SUCCESS; + } + + infiniStatus_t calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)}; + auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)}; + auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)}; + + if (launch_swiglu(stream, + out_nt, + up_nt, + gate_nt, + out_shape_.size(), + dtype_, + 1024)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector out_shape_; + std::vector out_strides_; + + std::vector up_shape_; + std::vector up_strides_; + + std::vector gate_shape_; + std::vector gate_strides_; + + infiniDtype_t dtype_; +}; +} // namespace op::swiglu::ninetoothed + +#endif // SWIGLU_H diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.py b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py new file mode 100644 index 000000000..62074a84b --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(output, up, gate): + output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git 
a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index 9d8e6406a..b3fabba32 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -6,14 +6,22 @@ #include "cpu/swiglu_cpu.h" #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) +#include "ninetoothed/swiglu.h" +#else #include "nvidia/swiglu_nvidia.cuh" #endif +#endif #ifdef ENABLE_KUNLUN_API #include "kunlun/swiglu_kunlun.h" #endif #ifdef ENABLE_METAX_API +#if defined(ENABLE_NINETOOTHED) +#include "ninetoothed/swiglu.h" +#else #include "metax/swiglu_metax.h" #endif +#endif #ifdef ENABLE_CAMBRICON_API #include "bang/swiglu_bang.h" #endif @@ -46,11 +54,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( CREATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else CREATE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif @@ -61,8 +77,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( CREATE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_METAX, ninetoothed); +#else CREATE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API CREATE(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -92,11 +112,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des GET(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_NVIDIA, ninetoothed); +#else GET(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif @@ -107,8 +135,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des GET(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_METAX, ninetoothed); +#else GET(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API GET(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -145,11 +177,19 @@ __C infiniStatus_t infiniopSwiGLU( CALCULATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif @@ -160,8 +200,12 @@ __C infiniStatus_t infiniopSwiGLU( CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_METAX, ninetoothed); +#else CALCULATE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API CALCULATE(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -193,11 +237,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { DELETE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else 
DELETE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API DELETE(INFINI_DEVICE_QY, nvidia); #endif @@ -208,8 +260,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { DELETE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_METAX, ninetoothed); +#else DELETE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API DELETE(INFINI_DEVICE_CAMBRICON, bang); #endif From 97eced0e3ef3bb0ffd8deb8caaeb46cc80337370 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 06:37:17 +0000 Subject: [PATCH 19/25] issue/923 - ninetoothed kv caching for nv, il, mtx --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/kv_caching.hpp | 16 ++ include/infiniop.h | 1 + include/infiniop/ops/kv_caching.h | 31 ++++ python/infinicore/__init__.py | 2 + python/infinicore/ops/kv_caching.py | 13 ++ src/infinicore/ops/kv_caching/kv_caching.cc | 42 +++++ .../ops/kv_caching/kv_caching_infiniop.cc | 60 ++++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/kv_caching.hpp | 32 ++++ .../ops/kv_caching/ninetoothed/build.py | 27 ++++ .../ops/kv_caching/ninetoothed/kv_caching.h | 101 +++++++++++++ .../ops/kv_caching/ninetoothed/kv_caching.py | 66 ++++++++ src/infiniop/ops/kv_caching/operator.cc | 143 ++++++++++++++++++ test/infinicore/framework/base.py | 13 +- test/infinicore/ops/kv_caching.py | 134 ++++++++++++++++ 16 files changed, 681 insertions(+), 3 deletions(-) create mode 100644 include/infinicore/ops/kv_caching.hpp create mode 100644 include/infiniop/ops/kv_caching.h create mode 100644 python/infinicore/ops/kv_caching.py create mode 100644 src/infinicore/ops/kv_caching/kv_caching.cc create mode 100644 src/infinicore/ops/kv_caching/kv_caching_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/kv_caching.hpp create mode 100644 src/infiniop/ops/kv_caching/ninetoothed/build.py create mode 100644 src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h create mode 100644 src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py create mode 100644 src/infiniop/ops/kv_caching/operator.cc create mode 100644 test/infinicore/ops/kv_caching.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 772bc030e..4020bb36e 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -6,6 +6,7 @@ #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/flash_attention.hpp" +#include "ops/kv_caching.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" diff --git a/include/infinicore/ops/kv_caching.hpp b/include/infinicore/ops/kv_caching.hpp new file mode 100644 index 000000000..3a70c2824 --- /dev/null +++ b/include/infinicore/ops/kv_caching.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &); + +void kv_caching_(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths); +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 092868923..0ea2e2bc0 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -13,6 +13,7 @@ #include "infiniop/ops/flash_attention.h" 
#include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" +#include "infiniop/ops/kv_caching.h" #include "infiniop/ops/layer_norm.h" #include "infiniop/ops/logsoftmax.h" #include "infiniop/ops/lp_norm.h" diff --git a/include/infiniop/ops/kv_caching.h b/include/infiniop/ops/kv_caching.h new file mode 100644 index 000000000..e6efa48b3 --- /dev/null +++ b/include/infiniop/ops/kv_caching.h @@ -0,0 +1,31 @@ +#ifndef __INFINIOP_KV_CACHING_API_H__ +#define __INFINIOP_KV_CACHING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t; + +__C __export infiniStatus_t infiniopCreateKVCachingDescriptor( + infiniopHandle_t handle, + infiniopKVCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths); + +__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream); + +__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc); + +#endif diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 52a269ce5..49260bbaf 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -45,6 +45,7 @@ from infinicore.ops.add import add from infinicore.ops.add_rms_norm import add_rms_norm from infinicore.ops.attention import attention +from infinicore.ops.kv_caching import kv_caching from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow @@ -115,6 +116,7 @@ "add_rms_norm", "add_rms_norm_", "attention", + "kv_caching", "matmul", "mul", "narrow", diff --git a/python/infinicore/ops/kv_caching.py b/python/infinicore/ops/kv_caching.py new file mode 100644 index 000000000..b34f2346e --- /dev/null +++ b/python/infinicore/ops/kv_caching.py @@ -0,0 +1,13 @@ +from infinicore.lib import _infinicore + + +def kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + _infinicore.kv_caching_( + k_cache._underlying, + v_cache._underlying, + k._underlying, + v._underlying, + past_kv_lengths._underlying, + ) + + return k_cache, v_cache diff --git a/src/infinicore/ops/kv_caching/kv_caching.cc b/src/infinicore/ops/kv_caching/kv_caching.cc new file mode 100644 index 000000000..0110f7973 --- /dev/null +++ b/src/infinicore/ops/kv_caching/kv_caching.cc @@ -0,0 +1,42 @@ +#include "infinicore/ops/kv_caching.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(KVCaching); + +KVCaching::KVCaching(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths); + INFINICORE_GRAPH_OP_DISPATCH(k_cache->device().getType(), + k_cache, + v_cache, + k, + v, + past_kv_lengths); +} + +void KVCaching::execute(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(KVCaching, + k_cache, + v_cache, + k, + v, + past_kv_lengths); +} + +void kv_caching_(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor 
&past_kv_lengths) { + KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc new file mode 100644 index 000000000..53ea5f0ae --- /dev/null +++ b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc @@ -0,0 +1,60 @@ +#include "../infiniop_impl.hpp" +#include "infinicore/ops/kv_caching.hpp" + +namespace infinicore::op::kv_caching_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, KVCaching, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, k_cache, v_cache, k, v, past_kv_lengths; +}; + +void *plan(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, KVCaching, + seed, k_cache->desc(), v_cache->desc(), + k->desc(), v->desc(), past_kv_lengths->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, KVCaching, descriptor); + + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(past_kv_lengths)}; + + return planned; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopKVCaching( + planned->descriptor->desc, + nullptr, 0, + planned->k_cache->data(), + planned->v_cache->data(), + planned->k->data(), + planned->v->data(), + planned->past_kv_lengths->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(KVCaching, &plan, &run, cleanup); + +} // namespace infinicore::op::kv_caching_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index c53218990..c7dcae6ca 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -8,6 +8,7 @@ #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/flash_attention.hpp" +#include "ops/kv_caching.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -31,6 +32,7 @@ inline void bind(py::module &m) { bind_attention(m); bind_causal_softmax(m); bind_flash_attention(m); + bind_kv_caching(m); bind_linear(m); bind_matmul(m); bind_mul(m); diff --git a/src/infinicore/pybind11/ops/kv_caching.hpp b/src/infinicore/pybind11/ops/kv_caching.hpp new file mode 100644 index 000000000..2864312b2 --- /dev/null +++ b/src/infinicore/pybind11/ops/kv_caching.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include "infinicore/ops/kv_caching.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_kv_caching(py::module &m) { + m.def("kv_caching_", + &op::kv_caching_, + py::arg("k_cache"), + py::arg("v_cache"), + py::arg("k"), + py::arg("v"), + py::arg("past_kv_lengths"), + R"doc(In-place Key-Value Caching. + +Updates the KV cache in-place with new key and value tensors. 
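+
+Example (illustrative sketch via the Python wrapper infinicore.kv_caching; the
+shapes are the ones exercised in test/infinicore/ops/kv_caching.py, and the
+tensors are assumed to already exist on the target device):
+
+    # k_cache, v_cache: [batch, num_kv_heads, max_seq_len, head_dim]
+    # k, v:             [batch, num_kv_heads, new_seq_len, head_dim]
+    # past_kv_lengths:  [batch], int64, current cached length of each sequence
+    k_cache, v_cache = infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths)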
+ +Args: + k_cache: Key cache tensor to update in-place + v_cache: Value cache tensor to update in-place + k: New key tensor to append + v: New value tensor to append + past_kv_lengths: Tensor containing current sequence lengths for each batch +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/kv_caching/ninetoothed/build.py b/src/infiniop/ops/kv_caching/ninetoothed/build.py new file mode 100644 index 000000000..03481c86b --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/build.py @@ -0,0 +1,27 @@ +import ninetoothed +from . import kv_caching + +import infiniop.ninetoothed.build + + +def build(): + dtype_values = ( + ninetoothed.float16, + ninetoothed.bfloat16, + ninetoothed.float32, + ) + + constexpr_param_grid = { + "emb_dim": (1, 16, 32, 64, 128, 256), + "dtype": dtype_values, + "block_size_m": (64,), + "block_size_n": (64,), + } + + infiniop.ninetoothed.build.build( + kv_caching.premake, + constexpr_param_grid, + caller="cuda", + op_name="kv_caching", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h new file mode 100644 index 000000000..43388f58d --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h @@ -0,0 +1,101 @@ +#ifndef KV_CACHING_H +#define KV_CACHING_H + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/kv_caching.h" +#include "../../../ninetoothed/utils.h" + +namespace op::kv_caching::ninetoothed { +class Descriptor final : public InfiniopDescriptor { + +public: + Descriptor( + infiniopHandle_t handle, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id}, + k_cache_shape_{k_cache_desc->shape()}, + k_cache_strides_{k_cache_desc->strides()}, + v_cache_shape_{v_cache_desc->shape()}, + v_cache_strides_{v_cache_desc->strides()}, + k_shape_{k_desc->shape()}, + k_strides_{k_desc->strides()}, + v_shape_{v_desc->shape()}, + v_strides_{v_desc->strides()}, + past_kv_lengths_shape_{past_kv_lengths_desc->shape()}, + past_kv_lengths_strides_{past_kv_lengths_desc->strides()}, + dtype_{k_desc->dtype()} {} + + ~Descriptor() = default; + + size_t get_workspace_size() const { return 0; }; + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths}; + return INFINI_STATUS_SUCCESS; + } + + infiniStatus_t calculate( + void *workspace, size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream) const { + auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}}; + auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}}; + auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}}; + auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}}; + auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}}; + + if (launch_kv_caching(stream, + k_cache_nt, + 
v_cache_nt, + k_nt, + v_nt, + past_kv_lengths_nt, + k_shape_[3], + dtype_, + 64, 64)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector k_cache_shape_; + std::vector k_cache_strides_; + + std::vector v_cache_shape_; + std::vector v_cache_strides_; + + std::vector k_shape_; + std::vector k_strides_; + std::vector v_shape_; + std::vector v_strides_; + + std::vector past_kv_lengths_shape_; + std::vector past_kv_lengths_strides_; + + infiniDtype_t dtype_; +}; +} // namespace op::kv_caching::ninetoothed + +#endif // KV_CACHING_H diff --git a/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py new file mode 100644 index 000000000..dfc5088e9 --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py @@ -0,0 +1,66 @@ +import functools +import ninetoothed +from ninetoothed import Tensor + + +def arrangement( + k_cache, + v_cache, + k, + v, + past_lengths, + block_size_m=ninetoothed.block_size(), + block_size_n=ninetoothed.block_size(), +): + k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + + k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + + past_lengths_arranged = ( + past_lengths.tile((1,)) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .expand((-1, *k_arranged.shape)) + ) + + return ( + k_cache_arranged, + v_cache_arranged, + k_arranged, + v_arranged, + past_lengths_arranged, + ) + + +def application(k_cache, v_cache, k, v, past_lengths): + pos = past_lengths + + for i in range(k.shape[-2]): + k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0] + v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0] + + +def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None): + arrangement_ = functools.partial( + arrangement, block_size_m=block_size_m, block_size_n=block_size_n + ) + + shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256}) + + tensors = ( + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(1, dtype=ninetoothed.int64), + ) + + if emb_dim is not None: + for tensor in tensors: + tensor.shape = tensor.shape[:-1] + (emb_dim,) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc new file mode 100644 index 000000000..34bdf9a99 --- /dev/null +++ b/src/infiniop/ops/kv_caching/operator.cc @@ -0,0 +1,143 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/kv_caching.h" + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API) +#include "ninetoothed/kv_caching.h" +#endif +#endif + +__C infiniStatus_t infiniopCreateKVCachingDescriptor( + infiniopHandle_t handle, + infiniopKVCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return 
op::kv_caching::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + k_cache, \ + v_cache, \ + k, \ + v, \ + past_kv_lengths) + + switch (handle->device) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + CREATE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetKVCachingWorkspaceSize( + infiniopKVCachingDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->get_workspace_size(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + GET_SIZE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopKVCaching( + infiniopKVCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream) + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + CALCULATE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyKVCachingDescriptor( + infiniopKVCachingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + DELETE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + DELETE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} diff --git a/test/infinicore/framework/base.py b/test/infinicore/framework/base.py index 87222b299..80dcb3eb1 100644 --- a/test/infinicore/framework/base.py +++ b/test/infinicore/framework/base.py @@ -342,7 +342,10 @@ def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target for i, inp in enumerate(inputs): if isinstance(inp, torch.Tensor): # Clone only if this input will be used for comparison - if comparison_target == i: + if comparison_target == i or ( + isinstance(comparison_target, (list, tuple)) + and i in comparison_target + ): cloned_inp = clone_torch_tensor(inp) infini_tensor = infinicore_tensor_from_torch(cloned_inp) cloned_tensors.append(cloned_inp) @@ -508,7 +511,9 @@ def run_test(self, device, test_case, config): # Handle multiple outputs comparison # Determine what to 
compare based on comparison_target - if comparison_target is None: + if comparison_target is None or isinstance( + comparison_target, (list, tuple) + ): # Compare return values (out-of-place multiple outputs) torch_comparison = torch_result infini_comparison = infini_result @@ -573,7 +578,9 @@ def run_test(self, device, test_case, config): # ========================================================================== else: # Determine comparison targets for single output - if comparison_target is None: + if comparison_target is None or isinstance( + comparison_target, (list, tuple) + ): # Compare return values (out-of-place) torch_comparison = torch_result infini_comparison = infini_result diff --git a/test/infinicore/ops/kv_caching.py b/test/infinicore/ops/kv_caching.py new file mode 100644 index 000000000..4ca857586 --- /dev/null +++ b/test/infinicore/ops/kv_caching.py @@ -0,0 +1,134 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + TensorSpec, + TensorInitializer, + TestCase, + GenericTestRunner, + is_broadcast, +) + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (shape (bs, nkvh, seq_len, hd), strides) +_TEST_CASES_DATA = [ + ((1, 1, 8, 1), None), + ((1, 8, 32, 32), None), + ((8, 8, 64, 32), None), + ((1, 32, 8, 64), (32768, 1024, 64, 1)), + ((4, 8, 32, 16), (65536, 8192, 256, 16)), + ((8, 16, 64, 128), (8388608, 524288, 8192, 1)), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 0, "rtol": 0}, + infinicore.bfloat16: {"atol": 0, "rtol": 0}, + infinicore.float32: {"atol": 0, "rtol": 0}, +} + +# Data types to test +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + test_cases = [] + + for data in _TEST_CASES_DATA: + import random + + cache_shape = data[0] + kv_shape = ( + cache_shape[0], + cache_shape[1], + random.randint(1, cache_shape[2]), + cache_shape[3], + ) + past_shape = (cache_shape[0],) + + strides = data[1] + + past_length = random.randint(0, cache_shape[2] - kv_shape[2]) + + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0}) + + cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype) + kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype) + + past_kv_lengths_spec = TensorSpec.from_tensor( + past_shape, + None, + infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=past_length, + high=past_length + 1, + ) + + test_cases.append( + TestCase( + inputs=[ + cache_spec, + cache_spec, + kv_spec, + kv_spec, + past_kv_lengths_spec, + ], + kwargs={}, + output_spec=None, + comparison_target=[0, 1], + tolerance=tolerance, + description=f"KV Caching", + ) + ) + + return test_cases + + +def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + batch_size, num_kv_heads, _, head_dim = k_cache.shape + seq_len = k.shape[2] + + for b in range(batch_size): + past_len = past_kv_lengths[b].item() + for h in range(num_kv_heads): + k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :] + v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :] + + return k_cache, v_cache + + +def infinicore_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths) + return k_cache, 
v_cache + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("KV Caching") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_kv_caching(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore_kv_caching(*args, **kwargs) + + +def main(): + test_runner = GenericTestRunner(OpTest) + test_runner.run_and_exit() + + +if __name__ == "__main__": + main() From 1c18c046d9d4f6f295741da6e2cf7e3080f784af Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 23 Jan 2026 13:22:33 +0000 Subject: [PATCH 20/25] issue/979 optimize paged attention --- .../ops/paged_attention/cuda/kernel_v2.cuh | 2085 +++++++++++++++ src/infiniop/ops/paged_attention/info.h | 149 +- .../nvidia/paged_attention_hd128.cu | 1024 +++++++ .../nvidia/paged_attention_hd64.cu | 524 ++++ .../nvidia/paged_attention_nvidia.cu | 425 ++- .../cuda/kernel_v2.cuh | 2361 +++++++++++++++++ .../ops/paged_attention_prefill/info.h | 166 +- .../nvidia/paged_attention_prefill_nvidia.cu | 1720 +++++++++++- test/infiniop/paged_attention.py | 3 +- test/infiniop/paged_attention_prefill.py | 3 +- 10 files changed, 8209 insertions(+), 251 deletions(-) create mode 100644 src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh create mode 100644 src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu create mode 100644 src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu create mode 100644 src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh diff --git a/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh new file mode 100644 index 000000000..e63dd68e2 --- /dev/null +++ b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh @@ -0,0 +1,2085 @@ +#ifndef __PAGED_ATTENTION_KERNEL_V2_CUH__ +#define __PAGED_ATTENTION_KERNEL_V2_CUH__ + +namespace op::paged_attention::cuda { + +struct OnlineSoftmaxState { + float m = -INFINITY; + float l = 0.0f; + + __device__ __forceinline__ void update(float x, float &alpha, float &beta) { + const float m_new = fmaxf(m, x); + alpha = expf(m - m_new); + beta = expf(x - m_new); + l = l * alpha + beta; + m = m_new; + } +}; +__device__ __forceinline__ float warpReduceSum(float x) { + for (int offset = 16; offset > 0; offset >>= 1) { + x += __shfl_down_sync(0xffffffff, x, offset); + } + return x; +} + +__device__ __forceinline__ float warpReduceMax(float x) { + for (int offset = 16; offset > 0; offset >>= 1) { + x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset)); + } + return x; +} + +__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) { + return static_cast(__cvta_generic_to_shared(ptr)); +} + +__device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + const unsigned int dst = cvtaToShared(dst_shared); + asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global)); +#else + auto *dst = reinterpret_cast(dst_shared); + const auto *src = reinterpret_cast(src_global); + *dst = *src; +#endif +} + +__device__ __forceinline__ void cpAsyncCommit() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void cpAsyncWaitGroup() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +// cp.async.wait_group requires a compile-time immediate, 
so for small fixed +// stage counts we provide a tiny runtime switch. +__device__ __forceinline__ void cpAsyncWaitGroupRt(int n) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + if (n <= 0) { + cpAsyncWaitGroup<0>(); + } else if (n == 1) { + cpAsyncWaitGroup<1>(); + } else { + // Clamp to 2 because v0.4 CTA kernel uses STAGES=3. + cpAsyncWaitGroup<2>(); + } +#else + (void)n; +#endif +} + +__device__ __forceinline__ void cpAsyncWaitAll() { + cpAsyncWaitGroup<0>(); +} + +template +__device__ void flashAttentionDecodeWarpKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + // q/out are [num_seqs, num_heads, head_size] + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + const int pbs = static_cast(page_block_size); + + // Iterate by blocks to avoid per-token division/mod and redundant block_table loads. + // Note: Per-token cp.async prefetching is generally too fine-grained for decode and can regress. + // We keep the warp kernel simple and reserve cp.async pipelining for CTA tile kernels. 
+ int t_base = 0; + for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) { + int physical_block = 0; + if (lane == 0) { + physical_block = static_cast(block_table[logical_block]); + } + physical_block = __shfl_sync(0xffffffff, physical_block, 0); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, seq_len - t_base); + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int t = t_base + token_in_block; + const Tdata *k_ptr = k_base + token_in_block * k_row_stride; + const Tdata *v_ptr = v_base + token_in_block * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + + qk = warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float o = acc[i] * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + 
out_ptr[dim] = static_cast(o); + } + } +} + +// Split-KV decode (FA2-style): each split scans a shard of KV and writes partial (m, l, acc) +// to workspace, then a combine kernel merges splits into final out. +template +__device__ void flashAttentionDecodeSplitKvWarpKernel( + float *partial_acc, // [num_splits, num_seqs, num_heads, head_size] + float *partial_m, // [num_splits, num_seqs, num_heads] + float *partial_l, // [num_splits, num_seqs, num_heads] + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int split_idx = static_cast(blockIdx.z); + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0 || num_splits <= 0) { + return; + } + + // Split the [0, seq_len) range into num_splits contiguous shards. + const int shard = (seq_len + num_splits - 1) / num_splits; + const int start = split_idx * shard; + const int end = min(seq_len, start + shard); + if (start >= end) { + // Empty shard => write neutral element. + const int n = gridDim.y * gridDim.x; + const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx); + if (lane == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = 0.0f; + } + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + const int pbs = static_cast(page_block_size); + + // Scan only [start, end). 
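+    // Each split emits a partial softmax state instead of a normalized output:
+    //   partial_m[s]   = running max of the base-2 scores seen in this shard
+    //   partial_l[s]   = sum of 2^(score - partial_m[s]) over this shard
+    //   partial_acc[s] = matching un-normalized weighted V sum
+    // The combine kernel then merges the shards as
+    //   m      = max_s partial_m[s]
+    //   l      = sum_s partial_l[s] * 2^(partial_m[s] - m)
+    //   out[d] = (sum_s partial_acc[s][d] * 2^(partial_m[s] - m)) / (l + 1e-6f)
+    // which reproduces the single-pass result up to floating-point rounding.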
+ int t = start; + int logical_block = t / pbs; + int token_in_block = t - logical_block * pbs; + for (; t < end; ++logical_block) { + int physical_block = 0; + if (lane == 0) { + physical_block = static_cast(block_table[logical_block]); + } + physical_block = __shfl_sync(0xffffffff, physical_block, 0); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, end - logical_block * pbs); + for (; token_in_block < token_end && t < end; ++token_in_block, ++t) { + const Tdata *k_ptr = k_base + token_in_block * k_row_stride; + const Tdata *v_ptr = v_base + token_in_block * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + + qk = warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + token_in_block = 0; + } + + const int n = gridDim.y * gridDim.x; + const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx); + if (lane == 0) { + partial_m[idx] = m; + partial_l[idx] = l; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = acc[i]; + } +} + +template +__device__ void flashAttentionDecodeSplitKvCombineWarpKernel( + 
Tdata *out_, + const float *partial_acc, // [num_splits, num_seqs, num_heads, head_size] + const float *partial_m, // [num_splits, num_seqs, num_heads] + const float *partial_l, // [num_splits, num_seqs, num_heads] + int num_splits, + ptrdiff_t o_stride) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int n = gridDim.y * gridDim.x; + const int base = (seq_idx * gridDim.x + head_idx); + + float m = -INFINITY; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + m = fmaxf(m, partial_m[s * n + base]); + } + } + m = __shfl_sync(0xffffffff, m, 0); + + float l = 0.0f; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[s * n + base]; + const float ls = partial_l[s * n + base]; + if (ls > 0.0f) { + l += ls * exp2f(ms - m); + } + } + } + l = __shfl_sync(0xffffffff, l, 0); + const float inv_l = 1.0f / (l + 1e-6f); + + // Combine acc for each dim. + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + float acc = 0.0f; + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[s * n + base]; + const float w = exp2f(ms - m); + acc += partial_acc[(s * n + base) * HEAD_SIZE + dim] * w; + } + const float o = acc * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +// Split-KV decode with a CTA tile kernel (FA2-style): each CTA scans a shard of KV, +// writes partial (m, l, acc) to workspace, then a combine kernel merges splits. 
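+// Workspace consumed by the split-KV kernels and the combine step (all float32,
+// laid out exactly as the parameter comments below state):
+//   partial_acc : num_splits * num_seqs * num_heads * head_size elements
+//   partial_m   : num_splits * num_seqs * num_heads elements
+//   partial_l   : num_splits * num_seqs * num_heads elements
+// As a rough host-side sizing sketch (the actual allocation is done by the
+// NVIDIA launcher, which is not part of this hunk):
+//   workspace_bytes ~= num_splits * num_seqs * num_heads * (head_size + 2) * sizeof(float)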
+template +__device__ void flashAttentionDecodeSplitKvCtaKernel( + float *partial_acc, // [num_splits, num_seqs, num_heads, head_size] + float *partial_m, // [num_splits, num_seqs, num_heads] + float *partial_l, // [num_splits, num_seqs, num_heads] + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + + constexpr int kWarpSize = 32; + static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32."); + static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small."); + constexpr int NUM_WARPS = CTA_THREADS / kWarpSize; + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS."); + constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t) + static_assert(kPack == 2 || kPack == 4, "v0.4 split-kv CTA kernel supports kPack=2/4 only."); + constexpr int kPackedDims = CTA_THREADS; + constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize; + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int split_idx = static_cast(blockIdx.z); + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0 || num_splits <= 0) { + return; + } + + // Split the [0, seq_len) range into num_splits contiguous shards. + const int shard = (seq_len + num_splits - 1) / num_splits; + const int start = split_idx * shard; + const int end = min(seq_len, start + shard); + + const int n = gridDim.y * gridDim.x; + const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx); + + if (start >= end) { + // Empty shard => write neutral element. 
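+        // (m = -inf, l = 0, acc = 0) is the identity of the log-sum-exp merge:
+        // the combine step skips shards with l == 0 and weights acc by
+        // 2^(m - m_global), which is 0 for m = -inf, so an empty shard never
+        // perturbs the final output.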
+ if (tid == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } + const int dim = tid * kPack; + if constexpr (kPack == 2) { + partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f; + } else { + partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 2] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 3] = 0.0f; + } + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + + const int dim = tid * kPack; + float q0 = 0.0f, q1 = 0.0f, q2 = 0.0f, q3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 qh2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __half22float2(qh2); + q0 = qf.x; + q1 = qf.y; + } else { + const half2 qh2_0 = *reinterpret_cast(q_ptr + dim + 0); + const half2 qh2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __half22float2(qh2_0); + const float2 qf1 = __half22float2(qh2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 qb2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __bfloat1622float2(qb2); + q0 = qf.x; + q1 = qf.y; + } else { + const __nv_bfloat162 qb2_0 = *reinterpret_cast(q_ptr + dim + 0); + const __nv_bfloat162 qb2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __bfloat1622float2(qb2_0); + const float2 qf1 = __bfloat1622float2(qb2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else +#endif + { + q0 = static_cast(q_ptr[dim + 0]); + q1 = static_cast(q_ptr[dim + 1]); + if constexpr (kPack == 4) { + q2 = static_cast(q_ptr[dim + 2]); + q3 = static_cast(q_ptr[dim + 3]); + } + } + + float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f; + + float m = -INFINITY; + float l = 0.0f; + + __shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps]; + __shared__ float alpha_shared; + __shared__ float weights_shared[TOKENS_PER_TILE]; + + const int pbs = static_cast(page_block_size); + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + static_assert(sizeof(Tdata) == 2, "CTA split-kv kernel assumes fp16/bf16."); + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + constexpr int STAGES = 3; + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + + const int first_block = start / pbs; + const int last_block = (end - 1) / pbs; + + for (int logical_block = first_block; logical_block <= last_block; ++logical_block) { + const int physical_block = static_cast(block_table[logical_block]); + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int t_base = logical_block * pbs; + const int token_begin = (logical_block == first_block) ? 
(start - t_base) : 0; + const int token_end = (logical_block == last_block) ? (end - t_base) : pbs; + const int token_count = token_end - token_begin; + if (token_count <= 0) { + continue; + } + + const int num_tiles = (token_count + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = token_begin + ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = token_begin + tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + float partial[TOKENS_PER_TILE]; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + if (j < tile_n) { + float k0 = 0.0f, k1 = 0.0f, k2 = 0.0f, k3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 kh2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __half22float2(kh2); + k0 = kf.x; + k1 = kf.y; + } else { + const half2 kh2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const half2 kh2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __half22float2(kh2_0); + const float2 kf1 = __half22float2(kh2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 kb2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __bfloat1622float2(kb2); + k0 = kf.x; + k1 = kf.y; + } else { + const __nv_bfloat162 kb2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const __nv_bfloat162 kb2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __bfloat1622float2(kb2_0); + const float2 kf1 = __bfloat1622float2(kb2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else +#endif + { + k0 = static_cast(sh_k[buf][j][dim + 0]); + k1 = static_cast(sh_k[buf][j][dim + 1]); + if constexpr (kPack == 4) { + k2 = static_cast(sh_k[buf][j][dim + 2]); + k3 = static_cast(sh_k[buf][j][dim + 3]); + } + } + if constexpr (kPack == 2) { + partial[j] = fmaf(q0, k0, q1 * k1); + } else { + partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3))); + } + } else { + partial[j] = 0.0f; + } + } + +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float sum = warpReduceSum(partial[j]); + if (lane == 0 && warp_id < kComputeWarps) { + 
warp_sums[j][warp_id] = sum; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + if (warp_id == 0) { + float score = -INFINITY; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + float qk = 0.0f; +#pragma unroll + for (int w = 0; w < kComputeWarps; ++w) { + qk += warp_sums[lane][w]; + } + const int t = t_base + token_in_block + lane; + score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + } + + float tile_max = warpReduceMax(score); + tile_max = __shfl_sync(0xffffffff, tile_max, 0); + + float m_new = 0.0f; + if (lane == 0) { + m_new = fmaxf(m, tile_max); + } + m_new = __shfl_sync(0xffffffff, m_new, 0); + + float w = 0.0f; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + w = exp2f(score - m_new); + } + if (lane < TOKENS_PER_TILE) { + weights_shared[lane] = (lane < tile_n) ? w : 0.0f; + } + + const float tile_sum = warpReduceSum(w); + if (lane == 0) { + const float alpha = exp2f(m - m_new); + alpha_shared = alpha; + l = l * alpha + tile_sum; + m = m_new; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + const float alpha = alpha_shared; + float sum_wv0 = 0.0f, sum_wv1 = 0.0f, sum_wv2 = 0.0f, sum_wv3 = 0.0f; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float w = weights_shared[j]; + float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 vh2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __half22float2(vh2); + v0 = vf.x; + v1 = vf.y; + } else { + const half2 vh2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const half2 vh2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __half22float2(vh2_0); + const float2 vf1 = __half22float2(vh2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 vb2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __bfloat1622float2(vb2); + v0 = vf.x; + v1 = vf.y; + } else { + const __nv_bfloat162 vb2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const __nv_bfloat162 vb2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __bfloat1622float2(vb2_0); + const float2 vf1 = __bfloat1622float2(vb2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else +#endif + { + v0 = static_cast(sh_v[buf][j][dim + 0]); + v1 = static_cast(sh_v[buf][j][dim + 1]); + if constexpr (kPack == 4) { + v2 = static_cast(sh_v[buf][j][dim + 2]); + v3 = static_cast(sh_v[buf][j][dim + 3]); + } + } + sum_wv0 = fmaf(w, v0, sum_wv0); + sum_wv1 = fmaf(w, v1, sum_wv1); + if constexpr (kPack == 4) { + sum_wv2 = fmaf(w, v2, sum_wv2); + sum_wv3 = fmaf(w, v3, sum_wv3); + } + } + acc0 = acc0 * alpha + sum_wv0; + acc1 = acc1 * alpha + sum_wv1; + if constexpr (kPack == 4) { + acc2 = acc2 * alpha + sum_wv2; + acc3 = acc3 * alpha + sum_wv3; + } + + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = token_begin + prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + 
off; + const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + } + + cpAsyncWaitAll(); + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + + if (tid == 0) { + partial_m[idx] = m; + partial_l[idx] = l; + } + if constexpr (kPack == 2) { + partial_acc[idx * HEAD_SIZE + dim + 0] = acc0; + partial_acc[idx * HEAD_SIZE + dim + 1] = acc1; + } else { + partial_acc[idx * HEAD_SIZE + dim + 0] = acc0; + partial_acc[idx * HEAD_SIZE + dim + 1] = acc1; + partial_acc[idx * HEAD_SIZE + dim + 2] = acc2; + partial_acc[idx * HEAD_SIZE + dim + 3] = acc3; + } +} + +template +__device__ void flashAttentionDecodeCtaPipelinedKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int NUM_WARPS = HEAD_SIZE / kWarpSize; + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; + + const float q_val = static_cast(q_ptr[tid]); + float acc = 0.0f; + + float m = -INFINITY; + float l = 0.0f; + + __shared__ Tdata sh_k[2][HEAD_SIZE]; + __shared__ Tdata sh_v[2][HEAD_SIZE]; + __shared__ float warp_sums[NUM_WARPS]; + __shared__ float alpha_s; + __shared__ float beta_s; + __shared__ int physical_block_s; + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + + const int pbs = static_cast(page_block_size); + + // Prefetch the very first token. 
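+    // The loop below then runs a two-stage (double-buffered) pipeline: while the
+    // CTA computes on sh_k/sh_v[buf] for token t, the 16-byte cp.async copies for
+    // token t + 1 are already in flight into buf ^ 1; after the accumulate we
+    // drain the in-flight group, sync, and flip buf, so global K/V reads overlap
+    // with the dot product and the softmax update.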
+ int buf = 0; + int t_base = 0; + int token_in_block = 0; + int logical_block = 0; + { + if (tid == 0) { + physical_block_s = static_cast(block_table[0]); + } + __syncthreads(); + const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride; + if (tid < CHUNKS) { + const int off = tid * CHUNK_ELEMS; + cpAsyncCaSharedGlobal16(&sh_k[buf][off], (k_base + 0 * k_row_stride) + off); + cpAsyncCaSharedGlobal16(&sh_v[buf][off], (v_base + 0 * v_row_stride) + off); + } + cpAsyncCommit(); + cpAsyncWaitAll(); + __syncthreads(); + } + + for (int t = 0; t < seq_len; ++t) { + // Compute current token location within paged KV. + const int next_t = t + 1; + const bool has_next = next_t < seq_len; + + if (has_next) { + const int next_block = next_t / pbs; + const int next_in_block = next_t - next_block * pbs; + if (next_block != logical_block) { + logical_block = next_block; + if (tid == 0) { + physical_block_s = static_cast(block_table[logical_block]); + } + __syncthreads(); + } + + const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride; + const Tdata *k_src = k_base + next_in_block * k_row_stride; + const Tdata *v_src = v_base + next_in_block * v_row_stride; + if (tid < CHUNKS) { + const int off = tid * CHUNK_ELEMS; + cpAsyncCaSharedGlobal16(&sh_k[buf ^ 1][off], k_src + off); + cpAsyncCaSharedGlobal16(&sh_v[buf ^ 1][off], v_src + off); + } + cpAsyncCommit(); + } + + // Dot: each thread handles one dim, reduce across head dim. + const float k_val = static_cast(sh_k[buf][tid]); + float partial = q_val * k_val; + float warp_sum = warpReduceSum(partial); + if (lane == 0) { + warp_sums[warp_id] = warp_sum; + } + __syncthreads(); + + float qk = 0.0f; + if (warp_id == 0) { + float v = (lane < NUM_WARPS) ? 
warp_sums[lane] : 0.0f; + v = warpReduceSum(v); + if (lane == 0) { + qk = v; + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + const float alpha = exp2f(m - m_new); + const float beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + alpha_s = alpha; + beta_s = beta; + } + } + __syncthreads(); + + const float alpha = alpha_s; + const float beta = beta_s; + const float v_val = static_cast(sh_v[buf][tid]); + acc = acc * alpha + beta * v_val; + + if (has_next) { + cpAsyncWaitAll(); + __syncthreads(); + buf ^= 1; + } + } + + __shared__ float inv_l_s; + if (tid == 0) { + inv_l_s = 1.0f / (l + 1e-6f); + } + __syncthreads(); + out_ptr[tid] = static_cast(acc * inv_l_s); +} + +template +__device__ void flashAttentionDecodeCtaKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + constexpr int kWarpSize = 32; + static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32."); + static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small."); + constexpr int NUM_WARPS = CTA_THREADS / kWarpSize; + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + + // Each thread owns a small packed vector of head dims. This lets us shrink the + // CTA to 1-2 warps and reduce block-wide synchronization overhead. 
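+    // Thread tid owns the contiguous head dims [tid * kPack, tid * kPack + kPack),
+    // so with kPack == 2 a single half2 / __nv_bfloat162 load covers exactly one
+    // thread's slice, and with kPack == 4 two such loads do; q, k and v stay
+    // vectorized without any cross-thread data movement.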
+ static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS."); + constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t) + static_assert(kPack == 2 || kPack == 4, "v0.4 CTA tile kernel supports kPack=2/4 only."); + constexpr int kPackedDims = CTA_THREADS; + constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize; + const int dim = tid * kPack; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + // q/out are [num_seqs, num_heads, head_size] + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; + + float q0 = 0.0f; + float q1 = 0.0f; + float q2 = 0.0f; + float q3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 qh2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __half22float2(qh2); + q0 = qf.x; + q1 = qf.y; + } else { + const half2 qh2_0 = *reinterpret_cast(q_ptr + dim + 0); + const half2 qh2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __half22float2(qh2_0); + const float2 qf1 = __half22float2(qh2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 qb2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __bfloat1622float2(qb2); + q0 = qf.x; + q1 = qf.y; + } else { + const __nv_bfloat162 qb2_0 = *reinterpret_cast(q_ptr + dim + 0); + const __nv_bfloat162 qb2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __bfloat1622float2(qb2_0); + const float2 qf1 = __bfloat1622float2(qb2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else +#endif + { + q0 = static_cast(q_ptr[dim + 0]); + q1 = static_cast(q_ptr[dim + 1]); + if constexpr (kPack == 4) { + q2 = static_cast(q_ptr[dim + 2]); + q3 = static_cast(q_ptr[dim + 3]); + } + } + + float acc0 = 0.0f; + float acc1 = 0.0f; + float acc2 = 0.0f; + float acc3 = 0.0f; + + float m = -INFINITY; + float l = 0.0f; + + // Only the compute warps contribute QK partial sums. Keeping this array + // compact reduces shared-memory traffic and bank pressure. + __shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps]; + __shared__ float alpha_shared; + __shared__ float weights_shared[TOKENS_PER_TILE]; + + const int pbs = static_cast(page_block_size); + + static_assert(sizeof(Tdata) == 2, "CTA tile kernel assumes 16B chunks map to 8 elements for fp16/bf16."); + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + // Multi-stage cp.async pipeline. Using >= 3 stages allows us to keep + // multiple groups in-flight and overlap global->shared copies with compute. + constexpr int STAGES = 3; + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int t_base = 0; + for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, seq_len - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + // Ensure tile 0 is ready. We want to keep up to (STAGES - 1) groups + // in flight for overlap, but still make forward progress in the tail + // when we stop issuing new prefetch groups. + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + float partial[TOKENS_PER_TILE]; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + if (j < tile_n) { + float k0 = 0.0f; + float k1 = 0.0f; + float k2 = 0.0f; + float k3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 kh2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __half22float2(kh2); + k0 = kf.x; + k1 = kf.y; + } else { + const half2 kh2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const half2 kh2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __half22float2(kh2_0); + const float2 kf1 = __half22float2(kh2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 kb2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __bfloat1622float2(kb2); + k0 = kf.x; + k1 = kf.y; + } else { + const __nv_bfloat162 kb2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const __nv_bfloat162 kb2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __bfloat1622float2(kb2_0); + const float2 kf1 = __bfloat1622float2(kb2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else 
+#endif + { + k0 = static_cast(sh_k[buf][j][dim + 0]); + k1 = static_cast(sh_k[buf][j][dim + 1]); + if constexpr (kPack == 4) { + k2 = static_cast(sh_k[buf][j][dim + 2]); + k3 = static_cast(sh_k[buf][j][dim + 3]); + } + } + if constexpr (kPack == 2) { + partial[j] = fmaf(q0, k0, q1 * k1); + } else { + partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3))); + } + } else { + partial[j] = 0.0f; + } + } + +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + float sum = warpReduceSum(partial[j]); + // Only compute warps contribute to qk; load-only warps would + // otherwise write zeros and increase reduction overhead. + if (lane == 0 && warp_id < kComputeWarps) { + warp_sums[j][warp_id] = sum; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + if (warp_id == 0) { + // Distribute token-wise score computation across lanes to avoid + // serial loops in lane0. TOKENS_PER_TILE <= 16 by construction. + float score = -INFINITY; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + float qk = 0.0f; +#pragma unroll + for (int w = 0; w < kComputeWarps; ++w) { + qk += warp_sums[lane][w]; + } + const int t = t_base + token_in_block + lane; + score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + } + + float tile_max = warpReduceMax(score); + tile_max = __shfl_sync(0xffffffff, tile_max, 0); + + float m_new = 0.0f; + if (lane == 0) { + m_new = fmaxf(m, tile_max); + } + m_new = __shfl_sync(0xffffffff, m_new, 0); + + float w = 0.0f; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + w = exp2f(score - m_new); + } + + if (lane < TOKENS_PER_TILE) { + weights_shared[lane] = (lane < tile_n) ? w : 0.0f; + } + + float tile_sum = warpReduceSum(w); + if (lane == 0) { + const float alpha = exp2f(m - m_new); + alpha_shared = alpha; + l = l * alpha + tile_sum; + m = m_new; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + const float alpha = alpha_shared; + float sum_wv0 = 0.0f; + float sum_wv1 = 0.0f; + float sum_wv2 = 0.0f; + float sum_wv3 = 0.0f; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float w = weights_shared[j]; + float v0 = 0.0f; + float v1 = 0.0f; + float v2 = 0.0f; + float v3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 vh2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __half22float2(vh2); + v0 = vf.x; + v1 = vf.y; + } else { + const half2 vh2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const half2 vh2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __half22float2(vh2_0); + const float2 vf1 = __half22float2(vh2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 vb2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __bfloat1622float2(vb2); + v0 = vf.x; + v1 = vf.y; + } else { + const __nv_bfloat162 vb2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const __nv_bfloat162 vb2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __bfloat1622float2(vb2_0); + const float2 vf1 = __bfloat1622float2(vb2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else +#endif + { + v0 = static_cast(sh_v[buf][j][dim + 0]); + v1 = static_cast(sh_v[buf][j][dim + 1]); + if constexpr (kPack == 4) { + v2 = static_cast(sh_v[buf][j][dim + 2]); + v3 = 
static_cast(sh_v[buf][j][dim + 3]); + } + } + sum_wv0 = fmaf(w, v0, sum_wv0); + sum_wv1 = fmaf(w, v1, sum_wv1); + if constexpr (kPack == 4) { + sum_wv2 = fmaf(w, v2, sum_wv2); + sum_wv3 = fmaf(w, v3, sum_wv3); + } + } + acc0 = acc0 * alpha + sum_wv0; + acc1 = acc1 * alpha + sum_wv1; + if constexpr (kPack == 4) { + acc2 = acc2 * alpha + sum_wv2; + acc3 = acc3 * alpha + sum_wv3; + } + + // Prefetch the tile that will reuse this buffer (STAGES steps ahead). + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + // Before consuming the next tile, ensure at least one group + // completes. In steady state we keep (STAGES - 1) in flight; in + // the tail (no more prefetches) we gradually drain. + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + } + + // Drain any in-flight async copies before moving to the next paged block. + cpAsyncWaitAll(); + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + + __shared__ float inv_l_shared; + if (tid == 0) { + inv_l_shared = 1.0f / (l + 1e-6f); + } + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + { + const float s = inv_l_shared; + const float o0 = acc0 * s; + const float o1 = acc1 * s; + const float o2 = acc2 * s; + const float o3 = acc3 * s; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2half_rn(o0); + out_ptr[dim + 1] = __float2half_rn(o1); + if constexpr (kPack == 4) { + out_ptr[dim + 2] = __float2half_rn(o2); + out_ptr[dim + 3] = __float2half_rn(o3); + } + } else if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2bfloat16_rn(o0); + out_ptr[dim + 1] = __float2bfloat16_rn(o1); + if constexpr (kPack == 4) { + out_ptr[dim + 2] = __float2bfloat16_rn(o2); + out_ptr[dim + 3] = __float2bfloat16_rn(o3); + } + } else +#endif + { + out_ptr[dim + 0] = static_cast(o0); + out_ptr[dim + 1] = static_cast(o1); + if constexpr (kPack == 4) { + out_ptr[dim + 2] = static_cast(o2); + out_ptr[dim + 3] = static_cast(o3); + } + } + } +} + +// GQA/MQA fused decode kernel: one CTA computes outputs for NGROUPS query heads that +// share the same KV head. This reduces redundant K/V reads when num_heads > num_kv_heads. +// +// v0.4: implemented for head_dim=128 and NGROUPS=4 (common case: 32 Q heads / 8 KV heads). 
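+// Head mapping: blockIdx.x enumerates KV heads and query head g of the group is
+//   q_head = kv_head_idx * NGROUPS + g,  g in [0, NGROUPS)
+// e.g. with 32 query heads over 8 KV heads, KV head 3 serves query heads 12..15.
+// A host-side dispatcher should therefore only select this kernel when
+// num_heads == num_kv_heads * 4, head_size == 128 and no alibi slopes are given
+// (sketch of the selection criteria; the actual dispatch lives in the launcher
+// .cu files, not in this header).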
+template +__device__ void flashAttentionDecodeCtaGqaKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 128, "v0.4 GQA fused CTA kernel is implemented for head_size=128 only."); + static_assert(NGROUPS == 4, "v0.4 GQA fused CTA kernel is implemented for NGROUPS=4 only."); + static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32."); + static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small."); + constexpr int NUM_WARPS = CTA_THREADS / kWarpSize; + + // Pack dims per thread. For head_dim=128 and CTA_THREADS=64, kPack=2. + static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS."); + constexpr int kPack = HEAD_SIZE / CTA_THREADS; + static_assert(kPack == 2, "v0.4 GQA fused CTA kernel expects kPack=2."); + constexpr int kPackedDims = CTA_THREADS; + constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize; + + const int seq_idx = blockIdx.y; + const int kv_head_idx = blockIdx.x; + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + const int dim = tid * kPack; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + // v0.4 limitation: alibi slopes are per query head; support can be added later. + if (alibi_slopes_ != nullptr) { + return; + } + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + // q/out are [num_seqs, num_heads, head_size]. 
For a KV head, we handle NGROUPS query heads: + // q_head = kv_head * NGROUPS + g + float q0[NGROUPS]; + float q1[NGROUPS]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE; + const half2 qh2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __half22float2(qh2); + q0[g] = qf.x; + q1[g] = qf.y; + } + } else if constexpr (std::is_same_v) { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE; + const __nv_bfloat162 qb2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __bfloat1622float2(qb2); + q0[g] = qf.x; + q1[g] = qf.y; + } + } else +#endif + { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE; + q0[g] = static_cast(q_ptr[dim + 0]); + q1[g] = static_cast(q_ptr[dim + 1]); + } + } + + float acc0[NGROUPS]; + float acc1[NGROUPS]; + float m[NGROUPS]; + float l[NGROUPS]; +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + acc0[g] = 0.0f; + acc1[g] = 0.0f; + m[g] = -INFINITY; + l[g] = 0.0f; + } + + __shared__ float warp_sums[NGROUPS][TOKENS_PER_TILE][kComputeWarps]; + __shared__ float alpha_shared[NGROUPS]; + __shared__ float weights_shared[NGROUPS][TOKENS_PER_TILE]; + + const int pbs = static_cast(page_block_size); + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + static_assert(sizeof(Tdata) == 2, "CTA GQA kernel assumes fp16/bf16."); + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. 
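+    // One cp.async transaction moves 16 bytes = 8 fp16/bf16 elements, so a K or V
+    // row needs HEAD_SIZE / 8 transactions and a full tile issues
+    // CHUNKS * TOKENS_PER_TILE of them for each of K and V, spread over the CTA
+    // by the tid-strided load loops below.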
+ constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + constexpr int STAGES = 3; + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + + int t_base = 0; + for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, seq_len - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + // Compute QK partial sums for each group and each token in the tile. 
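+            // Per thread this is a kPack-wide (here 2-wide) dot fragment for every
+            // (group, token) pair; fragments are warp-reduced, then warp 0 sums the
+            // per-warp results gathered in warp_sums[] before the softmax update.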
+ float partial_qk[NGROUPS][TOKENS_PER_TILE]; +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + if (j < tile_n) { + float k0 = 0.0f; + float k1 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const half2 kh2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __half22float2(kh2); + k0 = kf.x; + k1 = kf.y; + } else if constexpr (std::is_same_v) { + const __nv_bfloat162 kb2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __bfloat1622float2(kb2); + k0 = kf.x; + k1 = kf.y; + } else +#endif + { + k0 = static_cast(sh_k[buf][j][dim + 0]); + k1 = static_cast(sh_k[buf][j][dim + 1]); + } + partial_qk[g][j] = fmaf(q0[g], k0, q1[g] * k1); + } else { + partial_qk[g][j] = 0.0f; + } + } + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float sum = warpReduceSum(partial_qk[g][j]); + if (lane == 0 && warp_id < kComputeWarps) { + warp_sums[g][j][warp_id] = sum; + } + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + if (warp_id == 0) { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + float score = -INFINITY; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + float qk = 0.0f; +#pragma unroll + for (int w = 0; w < kComputeWarps; ++w) { + qk += warp_sums[g][lane][w]; + } + score = qk * scale_log2; + } + + float tile_max = warpReduceMax(score); + tile_max = __shfl_sync(0xffffffff, tile_max, 0); + + float m_new = 0.0f; + if (lane == 0) { + m_new = fmaxf(m[g], tile_max); + } + m_new = __shfl_sync(0xffffffff, m_new, 0); + + float w = 0.0f; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + w = exp2f(score - m_new); + } + if (lane < TOKENS_PER_TILE) { + weights_shared[g][lane] = (lane < tile_n) ? 
w : 0.0f; + } + + const float tile_sum = warpReduceSum(w); + if (lane == 0) { + const float alpha = exp2f(m[g] - m_new); + alpha_shared[g] = alpha; + l[g] = l[g] * alpha + tile_sum; + m[g] = m_new; + } + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + float alpha[NGROUPS]; + float sum_wv0[NGROUPS]; + float sum_wv1[NGROUPS]; +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + alpha[g] = alpha_shared[g]; + sum_wv0[g] = 0.0f; + sum_wv1[g] = 0.0f; + } + +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + float v0 = 0.0f; + float v1 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const half2 vh2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __half22float2(vh2); + v0 = vf.x; + v1 = vf.y; + } else if constexpr (std::is_same_v) { + const __nv_bfloat162 vb2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __bfloat1622float2(vb2); + v0 = vf.x; + v1 = vf.y; + } else +#endif + { + v0 = static_cast(sh_v[buf][j][dim + 0]); + v1 = static_cast(sh_v[buf][j][dim + 1]); + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const float w = weights_shared[g][j]; + sum_wv0[g] = fmaf(w, v0, sum_wv0[g]); + sum_wv1[g] = fmaf(w, v1, sum_wv1[g]); + } + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + acc0[g] = acc0[g] * alpha[g] + sum_wv0[g]; + acc1[g] = acc1[g] * alpha[g] + sum_wv1[g]; + } + + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + } + + cpAsyncWaitAll(); + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + + // Write outputs for each group. 
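+    // Final normalization: the first NGROUPS threads compute 1 / (l[g] + 1e-6f)
+    // and publish it via shared memory so every thread can scale its own packed
+    // dims of acc[g] before the narrowing store back to Tdata.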
+ __shared__ float inv_l_shared[NGROUPS]; + if (tid < NGROUPS) { + inv_l_shared[tid] = 1.0f / (l[tid] + 1e-6f); + } + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + Tdata *out_ptr = out_ + seq_idx * o_stride + q_head * HEAD_SIZE; + const float s = inv_l_shared[g]; + const float o0 = acc0[g] * s; + const float o1 = acc1[g] * s; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2half_rn(o0); + out_ptr[dim + 1] = __float2half_rn(o1); + } else if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2bfloat16_rn(o0); + out_ptr[dim + 1] = __float2bfloat16_rn(o1); + } else +#endif + { + out_ptr[dim + 0] = static_cast(o0); + out_ptr[dim + 1] = static_cast(o1); + } + } +} +} // namespace op::paged_attention::cuda + +#endif // __PAGED_ATTENTION_KERNEL_V2_CUH__ diff --git a/src/infiniop/ops/paged_attention/info.h b/src/infiniop/ops/paged_attention/info.h index 216bb2360..4b840af69 100644 --- a/src/infiniop/ops/paged_attention/info.h +++ b/src/infiniop/ops/paged_attention/info.h @@ -13,92 +13,171 @@ class PagedAttentionInfo { PagedAttentionInfo() = default; public: - // --- Data Types and Scale --- infiniDtype_t dtype; + infiniDtype_t index_dtype; float scale; - // --- Shape Dimensions --- size_t num_seqs; size_t num_heads; size_t num_kv_heads; size_t head_size; - size_t block_size; + size_t page_block_size; size_t max_num_blocks_per_seq; - // --- Strides for Memory Layout --- ptrdiff_t q_stride; - ptrdiff_t kv_block_stride; - ptrdiff_t kv_head_stride; + ptrdiff_t k_batch_stride; + ptrdiff_t k_row_stride; + ptrdiff_t k_head_stride; + ptrdiff_t v_batch_stride; + ptrdiff_t v_row_stride; + ptrdiff_t v_head_stride; ptrdiff_t o_stride; + ptrdiff_t block_table_batch_stride; + ptrdiff_t cache_lens_stride; + static utils::Result create( infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t cache_lens_desc, const std::optional &alibi_slopes_desc, float scale) { auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) { + if (q_desc->ndim() != 3 || out_desc->ndim() != 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (block_tables_desc->dtype() != INFINI_DTYPE_I64) { + CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + const auto block_tables_dt = block_tables_desc->dtype(); + const auto cache_lens_dt = cache_lens_desc->dtype(); + 
const bool debug_dtype = (std::getenv("INFINIOP_FLASH_DEBUG_DTYPE") != nullptr); + const bool block_tables_ok = (block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32); + const bool cache_lens_ok = (cache_lens_dt == INFINI_DTYPE_I64) || (cache_lens_dt == INFINI_DTYPE_I32) || (cache_lens_dt == INFINI_DTYPE_U32); + if (!(block_tables_ok && cache_lens_ok)) { + if (debug_dtype) { + std::fprintf(stderr, + "[flash_attention] Bad index dtype: block_tables=%d cache_lens=%d (expected I32/I64/U32)\n", + static_cast(block_tables_dt), static_cast(cache_lens_dt)); + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } - - if (seq_lens_desc->dtype() != INFINI_DTYPE_I64) { + if (block_tables_dt != cache_lens_dt) { + // Keep them consistent to simplify backend dispatch. + if (debug_dtype) { + std::fprintf(stderr, + "[flash_attention] Mismatched index dtype: block_tables=%d cache_lens=%d\n", + static_cast(block_tables_dt), static_cast(cache_lens_dt)); + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } - // --- Extract shape dimensions --- + CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) { + if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (alibi_slopes_desc.value()->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + } + + // Shapes auto q_shape = q_desc->shape(); - auto k_cache_shape = k_cache_desc->shape(); + auto k_shape = k_cache_desc->shape(); + + const size_t num_seqs = q_shape[0]; + const size_t num_heads = q_shape[1]; + const size_t head_size = q_shape[2]; + + const size_t num_blocks = k_shape[0]; + (void)num_blocks; + const size_t page_block_size = k_shape[2]; + const size_t num_kv_heads = k_shape[1]; + + // if (page_block_size % 256 != 0) { + // printf("paged block size %zu\n", page_block_size); + // return INFINI_STATUS_BAD_TENSOR_SHAPE; + // } + if (head_size != 64 && head_size != 128) { + // First build only targets common FA2 head dims (expand later). + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_heads % num_kv_heads != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[1] != k_shape[1] || v_cache_desc->shape()[2] != k_shape[2] || v_cache_desc->shape()[3] != k_shape[3]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } - size_t num_seqs = q_shape[0]; - size_t num_heads = q_shape[1]; - size_t head_size = q_shape[2]; + if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } - if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) { - std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got " - << head_size << "." 
<< std::endl; + if (cache_lens_desc->shape()[0] != num_seqs) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - size_t num_kv_heads = k_cache_shape[1]; - size_t block_size = v_cache_desc->shape()[2]; // 使用V cache的block size维度更可靠 - size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // Strides (in elements) + const ptrdiff_t q_stride = q_desc->stride(0); + const ptrdiff_t o_stride = out_desc->stride(0); + + const ptrdiff_t k_batch_stride = k_cache_desc->stride(0); + const ptrdiff_t k_row_stride = k_cache_desc->stride(2); + const ptrdiff_t k_head_stride = k_cache_desc->stride(1); + + const ptrdiff_t v_batch_stride = v_cache_desc->stride(0); + const ptrdiff_t v_row_stride = v_cache_desc->stride(2); + const ptrdiff_t v_head_stride = v_cache_desc->stride(1); - // --- Calculate max_seq_len for shared memory allocation --- - // This is a safe upper bound. - // info.max_seq_len = info.max_num_blocks_per_seq * info.block_size; - // --- Extract strides for memory access --- - ptrdiff_t q_stride = q_desc->stride(0); - ptrdiff_t kv_block_stride = k_cache_desc->stride(0); - ptrdiff_t kv_head_stride = k_cache_desc->stride(1); - ptrdiff_t o_stride = out_desc->stride(0); + const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0); + const ptrdiff_t cache_lens_stride = cache_lens_desc->stride(0); return utils::Result(PagedAttentionInfo{ dtype, + block_tables_dt, scale, num_seqs, num_heads, num_kv_heads, head_size, - block_size, + page_block_size, max_num_blocks_per_seq, q_stride, - kv_block_stride, - kv_head_stride, - o_stride}); + k_batch_stride, + k_row_stride, + k_head_stride, + v_batch_stride, + v_row_stride, + v_head_stride, + o_stride, + block_table_batch_stride, + cache_lens_stride, + }); } }; diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu new file mode 100644 index 000000000..c16b48e48 --- /dev/null +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu @@ -0,0 +1,1024 @@ +#include + +#include +#include +#include +#include + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention::nvidia { + +namespace { +constexpr int kMaxSplits = 8; + +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline int getSmCount() { + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { + return 0; + } + int sm_count = 0; + if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) { + return 0; + } + return sm_count; +} + +// A lightweight FA2-style "waves" heuristic. +// +// Important: our split-kv kernel shards the KV sequence length, so the main "work" +// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k +// (max pages * page size), which matches common decode microbench where all seqs +// share the same cache length. +inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) { + if (sm_count <= 0) { + return 1; + } + if (num_heads == 0 || num_seqs == 0) { + return 1; + } + if (seqlen_k <= 256) { + return 1; + } + + const size_t base_blocks = num_heads * num_seqs; + int best_splits = 1; + // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens. 
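+    // Illustrative numbers (assumed, not from this patch): sm_count = 108,
+    // base_blocks = num_heads * num_seqs = 8, seqlen_k = 4096. Baseline score:
+    // ceilDiv(8, 108) * 4096 = 4096. With s = 8: blocks = 64, waves_split = 1,
+    // work_per_block = 512, waves_combine = 1, score = 1 * 512 + 1 = 513, so the
+    // heuristic favors large split counts when base_blocks under-fills the SMs;
+    // once base_blocks alone covers the device, waves grow with s and the gain fades.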
+ size_t best_score = (ceilDiv(base_blocks, static_cast(sm_count)) * seqlen_k); + + size_t prev_work_per_block = seqlen_k; + for (int s = 2; s <= kMaxSplits; ++s) { + const size_t blocks = base_blocks * static_cast(s); + const size_t waves_split = ceilDiv(blocks, static_cast(sm_count)); + const size_t work_per_block = ceilDiv(seqlen_k, static_cast(s)); + // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant. + if (work_per_block == prev_work_per_block) { + continue; + } + prev_work_per_block = work_per_block; + // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit. + const size_t waves_combine = ceilDiv(base_blocks, static_cast(sm_count)); + const size_t score = waves_split * work_per_block + waves_combine; + if (score < best_score) { + best_score = score; + best_splits = s; + } + } + return best_splits; +} +} // namespace + +inline bool envBool(const char *name) { + if (const char *env = std::getenv(name)) { + return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + return false; +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeWarpKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Default CTA variant (lower overhead). 
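+    // One block per (query head, sequence) pair: the launcher builds the grid as
+    // (num_heads, num_seqs, 1) and gives the CTA cta_threads (32 or 64) threads to
+    // cooperate on the KV scan, instead of the single warp used by the warp variant.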
+ op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaTile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Experimental 1-warp CTA variant for head_dim=128 (kPack=4). + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32Tile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaGqa4( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128). 
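+    // Only dispatched when num_heads == num_kv_heads * 4 and ALiBi is disabled (see
+    // launch_decode_hd128_impl), e.g. 32 query heads sharing 8 KV heads, so the K/V
+    // tiles staged for one KV head are reused across the four query heads of the group.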
+ op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t 
k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, o_stride); +} + +template +infiniStatus_t launch_decode_hd128_impl( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + + // Default decode config (2026-01-22): + // decode_flash_cta8_64_gqa_splitkv_4 + // Users can override any knob via the corresponding INFINIOP_FLASH_* env vars. + bool use_cta = true; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) { + // Backward-compatible: any non-"cta" value means "warp". + use_cta = (std::strcmp(env, "cta") == 0); + } + bool use_gqa_fused = true; + if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) { + if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) { + use_gqa_fused = false; + } else { + use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + } + int cta_tile = 8; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) { + const int v = std::atoi(env); + if (v == 8 || v == 16) { + cta_tile = v; + } + } + int cta_threads = 64; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) { + const int v = std::atoi(env); + if (v == 32 || v == 64) { + cta_threads = v; + } + } + dim3 block(use_cta ? 
static_cast(cta_threads) : 32); + + bool use_split = true; + bool use_split_auto = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + if (std::strcmp(env, "auto") == 0) { + use_split_auto = true; + use_split = false; + } else { + if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) { + use_split = false; + } else { + use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + } + } + int num_splits = 4; + bool fixed_num_splits = true; + if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) { + if (std::strcmp(env, "auto") == 0) { + fixed_num_splits = false; + } else { + num_splits = std::atoi(env); + fixed_num_splits = (num_splits > 0); + } + } + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH"); + auto dump_dispatch = [&](const char *path) { + if (!debug_dispatch) { + return; + } + // Avoid spamming: only print when the key dispatch signature changes. + struct Sig { + const char *path; + int dtype; + size_t heads; + size_t kv_heads; + size_t seqs; + size_t pbs; + size_t max_blocks; + int cta_tile; + int cta_threads; + int split; + int split_auto; + int num_splits; + int fixed; + int gqa_fused; + }; + static Sig last{}; + static bool has_last = false; + + Sig cur{ + path, + static_cast(dtype), + num_heads, + num_kv_heads, + num_seqs, + page_block_size, + max_num_blocks_per_seq, + cta_tile, + cta_threads, + static_cast(use_split), + static_cast(use_split_auto), + num_splits, + static_cast(fixed_num_splits), + static_cast(use_gqa_fused), + }; + + if (has_last && cur.path == last.path && cur.dtype == last.dtype && cur.heads == last.heads && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs && cur.pbs == last.pbs && cur.max_blocks == last.max_blocks && cur.cta_tile == last.cta_tile && cur.cta_threads == last.cta_threads && cur.split == last.split && cur.split_auto == last.split_auto && cur.num_splits == last.num_splits && cur.fixed == last.fixed && cur.gqa_fused == last.gqa_fused) { + return; + } + last = cur; + has_last = true; + + fprintf(stderr, + "[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu " + "pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n", + path, static_cast(dtype), num_heads, num_kv_heads, num_seqs, + page_block_size, max_num_blocks_per_seq, cta_tile, cta_threads, + static_cast(use_split), static_cast(use_split_auto), num_splits, static_cast(fixed_num_splits), + static_cast(use_gqa_fused)); + }; + + // Split-kv auto mode: decide whether to split based on a heuristic. + if (use_split_auto) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). + const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + // If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead. 
+ use_split = (num_splits > 1); + } + + // const bool debug_dispatch = [] { + // if (const char *env = std::getenv("INFINIOP_FLASH_DEBUG_DISPATCH")) { + // return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + // } + // return false; + // }(); + + // const char *selected_path = "unknown"; + + // Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled. + // This reuses K/V loads across query heads that share the same KV head. + // Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled). + if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0 && num_heads == num_kv_heads * 4) { + dump_dispatch("cta_gqa_fused"); + dim3 grid_gqa(static_cast(num_kv_heads), static_cast(num_seqs), 1); + if (dtype == INFINI_DTYPE_F16) { + flashAttentionDecodeHd128CtaGqa4<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, nullptr, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + flashAttentionDecodeHd128CtaGqa4<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, nullptr, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + dim3 grid(static_cast(num_heads), static_cast(num_seqs), 1); + if (use_split) { + dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp"); + // } + if (!fixed_num_splits) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). + const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + } + + const size_t n = num_seqs * num_heads; + const size_t acc_elems = static_cast(kMaxSplits) * n * 128; + const size_t m_elems = static_cast(kMaxSplits) * n; + const size_t l_elems = static_cast(kMaxSplits) * n; + const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float); + if (workspace == nullptr || workspace_size < needed_bytes) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *ws = static_cast(workspace); + float *partial_acc = ws; + float *partial_m = partial_acc + acc_elems; + float *partial_l = partial_m + m_elems; + + dim3 grid_split(static_cast(num_heads), static_cast(num_seqs), static_cast(num_splits)); + dim3 block_split(use_cta ? 
static_cast(cta_threads) : 32); + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_threads == 32) { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCta32Tile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta32<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } else { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCtaTile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } + } else { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + flashAttentionDecodeHd128SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_threads == 32) { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCta32Tile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta32<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } else { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCtaTile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + 
flashAttentionDecodeHd128SplitKvCta<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } + } else { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + flashAttentionDecodeHd128SplitKvCombine<__nv_bfloat16><<>>( + static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit"); + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_tile == 16) { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32Tile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128CtaTile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128Cta<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } + } else { + flashAttentionDecodeHd128Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_tile == 16) { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32Tile16<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128CtaTile16<<>>( + 
static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128Cta<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } + } else { + flashAttentionDecodeHd128Warp<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t launch_decode_hd128_i64( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int64_t *block_tables, + const int64_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd128_i32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int32_t *block_tables, + const int32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + 
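+// launch_decode_hd128_u32 below completes the set of index-dtype wrappers
+// (i64 / i32 / u32), matching the index dtypes accepted by PagedAttentionInfo for
+// block_tables and cache_lens; each wrapper only forwards to launch_decode_hd128_impl
+// with the corresponding index type.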
+infiniStatus_t launch_decode_hd128_u32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const uint32_t *block_tables, + const uint32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +} // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu new file mode 100644 index 000000000..421fd22ef --- /dev/null +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu @@ -0,0 +1,524 @@ +#include + +#include +#include +#include +#include + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention::nvidia { + +namespace { +constexpr int kMaxSplits = 8; + +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline int getSmCount() { + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { + return 0; + } + int sm_count = 0; + if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) { + return 0; + } + return sm_count; +} + +// A lightweight FA2-style "waves" heuristic. +// +// Important: our split-kv kernel shards the KV sequence length, so the main "work" +// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k +// (max pages * page size), which matches common decode microbench where all seqs +// share the same cache length. +inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) { + if (sm_count <= 0) { + return 1; + } + if (num_heads == 0 || num_seqs == 0) { + return 1; + } + if (seqlen_k <= 256) { + return 1; + } + + const size_t base_blocks = num_heads * num_seqs; + int best_splits = 1; + // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens. + size_t best_score = (ceilDiv(base_blocks, static_cast(sm_count)) * seqlen_k); + + size_t prev_work_per_block = seqlen_k; + for (int s = 2; s <= kMaxSplits; ++s) { + const size_t blocks = base_blocks * static_cast(s); + const size_t waves_split = ceilDiv(blocks, static_cast(sm_count)); + const size_t work_per_block = ceilDiv(seqlen_k, static_cast(s)); + // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant. + if (work_per_block == prev_work_per_block) { + continue; + } + prev_work_per_block = work_per_block; + // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit. 
+ const size_t waves_combine = ceilDiv(base_blocks, static_cast(sm_count)); + const size_t score = waves_split * work_per_block + waves_combine; + if (score < best_score) { + best_score = score; + best_splits = s; + } + } + return best_splits; +} +} // namespace + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeWarpKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Cta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Default CTA variant (lower overhead). + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64CtaTile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel( + partial_acc, partial_m, partial_l, + q, k_cache, 
v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, o_stride); +} + +template +infiniStatus_t launch_decode_hd64_impl( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + + dim3 grid(static_cast(num_heads), static_cast(num_seqs), 1); + bool use_cta = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) { + use_cta = (std::strcmp(env, "cta") == 0); + } + int cta_tile = 8; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) { + const int v = std::atoi(env); + if (v == 8 || v == 16) { + cta_tile = v; + } + } + // For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads. + dim3 block(32); + + bool use_split = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 4; + bool fixed_num_splits = false; + if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) { + if (std::strcmp(env, "auto") == 0) { + fixed_num_splits = false; + } else { + num_splits = std::atoi(env); + fixed_num_splits = (num_splits > 0); + } + } + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + if (use_split) { + if (use_cta) { + // We currently only implement the split-kv path with warp kernels. + // The CTA kernel is a separate non-split implementation. + static bool warned = false; + if (!warned) { + warned = true; + fprintf(stderr, + "[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta " + "(CTA split-kv not implemented yet)\n"); + } + } + + if (!fixed_num_splits) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). 
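+            // Illustrative numbers (assumed): 16 pages per sequence with a 256-token
+            // page gives seqlen_k = 4096, an upper bound even when live cache lengths
+            // are shorter.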
+ const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + } + + const size_t n = num_seqs * num_heads; + const size_t acc_elems = static_cast(kMaxSplits) * n * 64; + const size_t m_elems = static_cast(kMaxSplits) * n; + const size_t l_elems = static_cast(kMaxSplits) * n; + const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float); + if (workspace == nullptr || workspace_size < needed_bytes) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *ws = static_cast(workspace); + float *partial_acc = ws; + float *partial_m = partial_acc + acc_elems; + float *partial_l = partial_m + m_elems; + + dim3 grid_split(static_cast(num_heads), static_cast(num_seqs), static_cast(num_splits)); + dim3 block_split(32); + + if (dtype == INFINI_DTYPE_F16) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<__nv_bfloat16><<>>( + static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_tile == 16) { + flashAttentionDecodeHd64CtaTile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd64Cta<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + flashAttentionDecodeHd64Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, 
k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_tile == 16) { + flashAttentionDecodeHd64CtaTile16<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd64Cta<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + flashAttentionDecodeHd64Warp<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t launch_decode_hd64_i64( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int64_t *block_tables, + const int64_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd64_i32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int32_t *block_tables, + const int32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd64_u32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + 
infiniDtype_t dtype, + const uint32_t *block_tables, + const uint32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +} // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu index d544fd34a..18b6ef073 100644 --- a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu @@ -1,29 +1,68 @@ -#include +#include -#include "../../../devices/nvidia/nvidia_common.cuh" -#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include +#include +#include -#include "../../../reduce/cuda/reduce.cuh" -#include "../cuda/kernel.cuh" +#include "../../../devices/nvidia/nvidia_common.cuh" #include "paged_attention_nvidia.cuh" -template -INFINIOP_CUDA_KERNEL pagedAttention( - Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, - const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes, - const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq, - const size_t block_size, - const ptrdiff_t q_stride, - const ptrdiff_t kv_block_stride, - const ptrdiff_t kv_head_stride, - const ptrdiff_t o_stride) { - op::paged_attention::cuda::pagedAttentionKernel( - out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale, - max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride); -} - namespace op::paged_attention::nvidia { +infiniStatus_t launch_decode_hd64_i64( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd64_i32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd64_u32( + void *workspace, size_t 
workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd128_i64( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd128_i32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd128_u32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + struct Descriptor::Opaque { std::shared_ptr internal; }; @@ -40,108 +79,284 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t cache_lens_desc, const std::optional &alibi_slopes_desc, float scale) { - auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); - CHECK_RESULT(info); + + auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale); + CHECK_RESULT(info_res); + auto info = info_res.take(); + // Reserve workspace for optional split-kv decode (partial acc + m/l). + // Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits. 
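+        // Sizing sketch (illustrative numbers, matching the hd64/hd128 split-kv launchers):
+        // with n = num_seqs * num_heads the workspace is laid out as
+        //   [partial_acc: kMaxSplits * n * head_size floats][partial_m: kMaxSplits * n][partial_l: kMaxSplits * n],
+        // i.e. kMaxSplits * n * (head_size + 2) * sizeof(float) bytes in total.
+        // Example: num_seqs = 32, num_heads = 32, head_size = 128 gives
+        // per_split = 32 * 32 * 130 * 4 B = 520 KiB and workspace_bytes = 8 * per_split ~= 4.06 MiB.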
+ constexpr size_t kMaxSplits = 8; + const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float); + const size_t workspace_bytes = kMaxSplits * per_split; + *desc_ptr = new Descriptor( new Opaque{reinterpret_cast(handle)->internal()}, - info.take(), 0, handle->device, handle->device_id); + info, workspace_bytes, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } -template -infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache, - infiniDtype_t dtype, - const void *block_tables, const void *seq_lens, const void *alibi_slopes, - size_t num_heads, size_t num_seqs, - size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, - ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride, - cudaStream_t stream) { - dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1); - dim3 block(NUM_THREADS); - size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); - - if (dtype == INFINI_DTYPE_F16) { - pagedAttention - <<>>( - (half *)out, - (const half *)q, (const half *)k_cache, (const half *)v_cache, - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, - scale, max_num_blocks_per_seq, block_size, - q_stride, kv_block_stride, kv_head_stride, o_stride); - } else if (dtype == INFINI_DTYPE_BF16) { - pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> - <<>>( - (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache, - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, - scale, max_num_blocks_per_seq, block_size, - q_stride, kv_block_stride, kv_head_stride, o_stride); - } else if (dtype == INFINI_DTYPE_F32) { - pagedAttention - <<>>( - (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache, - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, - scale, max_num_blocks_per_seq, block_size, - q_stride, kv_block_stride, kv_head_stride, o_stride); - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - return INFINI_STATUS_SUCCESS; -} - infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache, - const void *block_tables, const void *seq_lens, const void *alibi_slopes, + const void *block_tables, const void *cache_lens, const void *alibi_slopes, void *stream_) const { - cudaStream_t stream = (cudaStream_t)stream_; - -#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \ - launchKernel<__H_SIZE, __B_SIZE>( \ - out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \ - _info.num_heads, _info.num_seqs, \ - _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \ - _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \ - stream); - -#define SWITCH_HEAD_SIZE(__B_SIZE) \ - switch (_info.head_size) { \ - case 16: \ - LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \ - break; \ - case 32: \ - LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \ - break; \ - case 64: \ - LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \ - break; \ - case 128: \ - LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \ - break; \ - case 256: \ - LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \ - break; \ - default: \ - return INFINI_STATUS_BAD_TENSOR_SHAPE; \ - } - if (_opaque->internal->maxThreadsPerBlock() 
== CUDA_BLOCK_SIZE_1024) { - SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024) - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { - SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512) - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { - SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096) + bool need_workspace = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + // "auto" may enable split-kv depending on the runtime heuristic. + need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); } else { - return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + // Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace. + need_workspace = (_info.head_size == 128); + } + if (need_workspace && workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } -#undef LAUNCH_HEADSIZE_BLOCKSIZE -#undef SWITCH_HEAD_SIZE + auto stream = static_cast(stream_); - return INFINI_STATUS_SUCCESS; + const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); + + if (_info.index_dtype == INFINI_DTYPE_I64) { + const auto *block_table_i64 = static_cast(block_tables); + const auto *cache_lens_i64 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_i64( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i64, cache_lens_i64, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_i64( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i64, cache_lens_i64, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (_info.index_dtype == INFINI_DTYPE_I32) { + const auto *block_table_i32 = static_cast(block_tables); + const auto *cache_lens_i32 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_i32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i32, cache_lens_i32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_i32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i32, cache_lens_i32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (_info.index_dtype == INFINI_DTYPE_U32) { + const auto 
*block_table_u32 = static_cast(block_tables); + const auto *cache_lens_u32 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_u32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_u32, cache_lens_u32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_u32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_u32, cache_lens_u32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; } } // namespace op::paged_attention::nvidia + +// #include + +// #include "../../../devices/nvidia/nvidia_common.cuh" +// #include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +// #include "../../../reduce/cuda/reduce.cuh" +// #include "../cuda/kernel.cuh" +// #include "paged_attention_nvidia.cuh" + +// template +// INFINIOP_CUDA_KERNEL pagedAttention( +// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, +// const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes, +// const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq, +// const size_t block_size, +// const ptrdiff_t q_stride, +// const ptrdiff_t kv_block_stride, +// const ptrdiff_t kv_head_stride, +// const ptrdiff_t o_stride) { +// op::paged_attention::cuda::pagedAttentionKernel( +// out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale, +// max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride); +// } + +// namespace op::paged_attention::nvidia { + +// struct Descriptor::Opaque { +// std::shared_ptr internal; +// }; + +// Descriptor::~Descriptor() { +// delete _opaque; +// } + +// infiniStatus_t Descriptor::create( +// infiniopHandle_t handle, +// Descriptor **desc_ptr, +// infiniopTensorDescriptor_t out_desc, +// infiniopTensorDescriptor_t q_desc, +// infiniopTensorDescriptor_t k_cache_desc, +// infiniopTensorDescriptor_t v_cache_desc, +// infiniopTensorDescriptor_t block_tables_desc, +// infiniopTensorDescriptor_t seq_lens_desc, +// const std::optional &alibi_slopes_desc, +// float scale) { +// auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); +// CHECK_RESULT(info); +// *desc_ptr = new Descriptor( +// new Opaque{reinterpret_cast(handle)->internal()}, +// info.take(), 0, handle->device, handle->device_id); + +// return INFINI_STATUS_SUCCESS; +// } + +// template +// infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache, +// infiniDtype_t dtype, +// const void *block_tables, const void *seq_lens, const void *alibi_slopes, +// size_t num_heads, size_t num_seqs, +// size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, +// ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, 
ptrdiff_t o_stride, +// cudaStream_t stream) { +// dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1); +// dim3 block(NUM_THREADS); +// size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); + +// if (dtype == INFINI_DTYPE_F16) { +// pagedAttention +// <<>>( +// (half *)out, +// (const half *)q, (const half *)k_cache, (const half *)v_cache, +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, +// scale, max_num_blocks_per_seq, block_size, +// q_stride, kv_block_stride, kv_head_stride, o_stride); +// } else if (dtype == INFINI_DTYPE_BF16) { +// pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> +// <<>>( +// (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache, +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, +// scale, max_num_blocks_per_seq, block_size, +// q_stride, kv_block_stride, kv_head_stride, o_stride); +// } else if (dtype == INFINI_DTYPE_F32) { +// pagedAttention +// <<>>( +// (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache, +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, +// scale, max_num_blocks_per_seq, block_size, +// q_stride, kv_block_stride, kv_head_stride, o_stride); +// } else { +// return INFINI_STATUS_BAD_TENSOR_DTYPE; +// } +// return INFINI_STATUS_SUCCESS; +// } + +// infiniStatus_t Descriptor::calculate( +// void *workspace, size_t workspace_size, +// void *out, const void *q, const void *k_cache, const void *v_cache, +// const void *block_tables, const void *seq_lens, const void *alibi_slopes, +// void *stream_) const { +// cudaStream_t stream = (cudaStream_t)stream_; + +// #define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \ +// launchKernel<__H_SIZE, __B_SIZE>( \ +// out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \ +// _info.num_heads, _info.num_seqs, \ +// _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \ +// _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \ +// stream); + +// #define SWITCH_HEAD_SIZE(__B_SIZE) \ +// switch (_info.head_size) { \ +// case 16: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \ +// break; \ +// case 32: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \ +// break; \ +// case 64: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \ +// break; \ +// case 128: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \ +// break; \ +// case 256: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \ +// break; \ +// default: \ +// return INFINI_STATUS_BAD_TENSOR_SHAPE; \ +// } + +// if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { +// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024) +// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { +// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512) +// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { +// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096) +// } else { +// return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; +// } + +// #undef LAUNCH_HEADSIZE_BLOCKSIZE +// #undef SWITCH_HEAD_SIZE + +// return INFINI_STATUS_SUCCESS; +// } + +// } // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh new file mode 100644 index 000000000..6790f12d8 --- /dev/null +++ 
b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh @@ -0,0 +1,2361 @@ +#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ +#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ + +#include +#include +#include +#include + +#include +#include + +// Reuse warp-level primitives and math helpers from decode flash_attention kernels. +#include "../../paged_attention/cuda/kernel_v2.cuh" + +namespace op::paged_attention_prefill::cuda { + +__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) { + size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1); + while (low <= high) { + size_t mid = (low + high) >> 1; + const size_t start = static_cast(cu_seqlens_q[mid]); + const size_t end = static_cast(cu_seqlens_q[mid + 1]); + if (token_idx >= start && token_idx < end) { + return mid; + } else if (token_idx < start) { + if (mid == 0) { + break; + } + high = mid - 1; + } else { + low = mid + 1; + } + } + return 0; +} + +template +__device__ void PagedAttentionPrefillWarpKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x; + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int q_token_local = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_token_local >= q_len) { + return; + } + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = history_len + q_token_local + 1; + if (allowed_k_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const int64_t q_token = q_start + static_cast(q_token_local); + const Tdata *q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + Tdata *out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + const int pbs = static_cast(page_block_size); + int t_base = 0; + for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) { + int physical_block = 0; + if (lane == 0) { + physical_block = static_cast(block_table[logical_block]); + } + physical_block = __shfl_sync(0xffffffff, physical_block, 0); + + const Tdata *k_base = k_cache_ + static_cast(physical_block) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(physical_block) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, allowed_k_len - t_base); + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int t = t_base + token_in_block; + const Tdata *k_ptr = k_base + static_cast(token_in_block) * k_row_stride; + const Tdata *v_ptr = v_base + static_cast(token_in_block) * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(t - causal_limit)) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + 
alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float o = acc[i] * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +template +__global__ void PagedAttentionPrefillWarpGlobalKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x; + const size_t head_idx = static_cast(blockIdx.x); + const size_t global_token_idx = static_cast(blockIdx.y); + + if (lane >= kWarpSize || head_idx >= num_heads || global_token_idx >= total_q_tokens) { + return; + } + + const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + + const int q_token_local = static_cast(global_token_idx - static_cast(q_start)); + if (q_token_local < 0 || q_token_local >= q_len) { + return; + } + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = history_len + q_token_local + 1; + if (allowed_k_len <= 0) { + return; + } + + const int num_queries_per_kv = static_cast(num_heads / num_kv_heads); + const int kv_head_idx = 
static_cast(head_idx) / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tdata *q_ptr = q_ + static_cast(global_token_idx) * q_stride + static_cast(head_idx) * q_head_stride; + Tdata *out_ptr = out_ + static_cast(global_token_idx) * o_stride + static_cast(head_idx) * o_head_stride; + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + const int pbs = static_cast(page_block_size); + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + // Iterate by pages to avoid per-token division/mod and redundant block_table loads. + int t_base = 0; + for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) { + const int32_t phys = static_cast(block_table[logical_block]); + const Tdata *k_base = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, allowed_k_len - t_base); + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int t = t_base + token_in_block; + const Tdata *k_ptr = k_base + static_cast(token_in_block) * k_row_stride; + const Tdata *v_ptr = v_base + static_cast(token_in_block) * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(t - causal_limit)) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - 
m_new); + l = l * alpha + beta; + m = m_new; + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float o = acc[i] * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +template +__global__ void PagedAttentionPrefillReferenceKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_heads, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + size_t num_seqs) { + + const size_t global_token_idx = static_cast(blockIdx.x); + const size_t head_idx = static_cast(blockIdx.y); + const size_t dim_idx = static_cast(threadIdx.x); + + if (dim_idx >= HEAD_SIZE || head_idx >= num_heads) { + return; + } + + const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); + const size_t q_token_idx = global_token_idx - static_cast(cu_seqlens_q_[seq_idx]); + const size_t q_len = static_cast(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]); + + const size_t total_kv_len = static_cast(total_kv_lens_[seq_idx]); + const size_t history_len = total_kv_len - q_len; + const size_t causal_limit = history_len + q_token_idx; + + const size_t num_queries_per_kv = num_heads / num_kv_heads; + const size_t kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + + const Tdata *q_vec = q_ + static_cast(global_token_idx) * q_stride + static_cast(head_idx) * q_head_stride; + Tdata *out_ptr = out_ + static_cast(global_token_idx) * o_stride + static_cast(head_idx) * o_head_stride; + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + const size_t pbs = page_block_size; + + Tcompute max_score = -INFINITY; + for (size_t t = 0; t <= causal_limit; ++t) { + const size_t page = t / pbs; + const size_t off = t - page * pbs; + const ptrdiff_t phys = static_cast(block_table[page]); + const Tdata *k_vec = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + + Tcompute score = 0; + for (size_t d = 0; d < HEAD_SIZE; ++d) { + score += static_cast(q_vec[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += static_cast(alibi_slope * static_cast(t - causal_limit)); + } + if (score > max_score) { + max_score = score; + } + } + + Tcompute sum_exp = 0; + for (size_t t = 0; t <= causal_limit; ++t) { + const size_t page = t / pbs; + const size_t off = t - page * pbs; + const ptrdiff_t phys = static_cast(block_table[page]); + const Tdata *k_vec = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + + Tcompute score = 0; + for (size_t d = 0; d < HEAD_SIZE; ++d) { + score += static_cast(q_vec[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += static_cast(alibi_slope * static_cast(t - causal_limit)); + } + sum_exp += static_cast(expf(static_cast(score - max_score))); + } + + const Tcompute inv_sum = static_cast(1.0f) / (sum_exp + static_cast(1e-6f)); + Tcompute acc = 0; + for (size_t t = 0; t <= causal_limit; ++t) { + const size_t page = t / pbs; + const size_t off = t - page * pbs; + const ptrdiff_t phys = static_cast(block_table[page]); + const Tdata *k_vec = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + + Tcompute score = 0; + for (size_t d = 0; d < HEAD_SIZE; ++d) { + score += static_cast(q_vec[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += static_cast(alibi_slope * static_cast(t - causal_limit)); + } + const Tcompute prob = static_cast(expf(static_cast(score - max_score))) * inv_sum; + + const Tdata *v_vec = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + acc += prob * static_cast(v_vec[dim_idx]); + } + + out_ptr[dim_idx] = static_cast(acc); +} + +template +__device__ void PagedAttentionPrefillWarpCtaKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 
0 && BLOCK_M <= 16, "BLOCK_M must be small (warp-per-query design)."); + static_assert(BLOCK_N == 64 || BLOCK_N == 128, "BLOCK_N must be 64/128 in v0.4."); + + constexpr int kWarpSize = 32; + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + // IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads() + // later. Tail tiles are handled by masking inactive warps. + if (m_start >= q_len) { + return; // uniform across the CTA + } + const bool is_active = (q_token_local < q_len); + + const int64_t kv_len_total_i64 = total_kv_lens_[seq_idx]; + const int kv_len_total = static_cast(kv_len_total_i64); + // history_len = total_kv_len - q_len (KV already includes current q tokens). + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + const Tdata *q_ptr = nullptr; + Tdata *out_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]); + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + // For this CTA, we only need to scan up to the max allowed k among active warps. + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + + __shared__ int32_t s_phys[BLOCK_N]; + __shared__ int32_t s_off[BLOCK_N]; + // Ensure shared-memory tiles are aligned for half2/bfloat162 vector loads. + __shared__ __align__(16) Tdata s_k[BLOCK_N * HEAD_SIZE]; + __shared__ __align__(16) Tdata s_v[BLOCK_N * HEAD_SIZE]; + + const int pbs = static_cast(page_block_size); + + for (int k_base = 0; k_base < max_allowed_k_len; k_base += BLOCK_N) { + const int tile_n = min(BLOCK_N, max_allowed_k_len - k_base); + + // Precompute page mapping once per token in the tile. 
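+            // Mapping example: with pbs == 256 the divide/mod reduces to a shift/mask, so
+            // kpos = 700 -> page = 700 >> 8 = 2, off = 700 & 255 = 188 (same result as 700 / 256 and 700 % 256).
+            // s_phys/s_off cache one (physical block, in-block offset) pair per token of this tile.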
+ for (int t = threadIdx.x; t < tile_n; t += blockDim.x) { + const int kpos = k_base + t; + const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs); + const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs); + const int32_t phys = static_cast(block_table[page]); + s_phys[t] = phys; + s_off[t] = off; + } + __syncthreads(); + + // Load K/V tile into shared memory (contiguous in head_dim). + const int tile_elems = tile_n * HEAD_SIZE; + for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) { + const int t = idx / HEAD_SIZE; + const int dim = idx - t * HEAD_SIZE; + const int32_t phys = s_phys[t]; + const int32_t off = s_off[t]; + const Tdata *k_base_ptr = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base_ptr = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + s_k[t * HEAD_SIZE + dim] = k_base_ptr[dim]; + s_v[t * HEAD_SIZE + dim] = v_base_ptr[dim]; + } + __syncthreads(); + + // Each warp processes one query token and scans the K/V tile. + for (int t = 0; t < tile_n; ++t) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + break; + } + const Tdata *k_ptr = s_k + t * HEAD_SIZE; + const Tdata *v_ptr = s_v + t * HEAD_SIZE; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + // Causal prefill: last position is (allowed_k_len - 1) for this query. 
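+                    // Bias term: alibi_slope * (kpos - causal_limit) is 0 at the newest visible key and
+                    // negative for older keys (e.g. slope = 0.5, kpos = 10, causal_limit = 12 -> -1.0).
+                    // The extra kLog2e factor keeps the bias in the same base-2 domain as scale_log2,
+                    // since the running softmax here uses exp2f instead of expf.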
+ score += (alibi_slope * static_cast(kpos - (allowed_k_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + + __syncthreads(); + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float out_val = acc[i] * inv_l; + if (!is_active) { + continue; + } + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(out_val); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(out_val); + } else { + out_ptr[dim] = static_cast(out_val); + } + } +} + +// Pipelined CTA kernel (FA2-style): stage K/V loads with cp.async and overlap global->shared +// copies with compute. +// +// Design notes: +// - Keep shared memory <= 48KB for compatibility with multi-arch builds that include SM75. +// - Iterate by paged blocks (logical pages) so each tile stays within one physical block and +// avoids per-token (page, off) mapping arrays in shared memory. +// - One warp computes one query token (same as warpcta kernels). Warps with shorter causal +// limits simply mask the tail tokens but still participate in CTA-wide barriers. 
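+// Shared-memory budget, one illustrative instantiation (BLOCK_M = 8 and STAGES = 2 are example
+// values permitted by the static_asserts below, not fixed by this kernel): with Tdata = half,
+// HEAD_SIZE = 128, TOKENS_PER_TILE = 32 the buffers are sh_k = sh_v = 2 * 32 * 128 * 2 B = 16 KiB each,
+// sh_scores = 8 * 32 * 4 B = 1 KiB and sh_q = 8 * 128 * 2 B = 2 KiB, i.e. ~35 KiB total, under the
+// 48 KB ceiling called out in the design notes above.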
+template +__device__ void PagedAttentionPrefillWarpCtaKernelPipelined( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <= 16."); + static_assert(TOKENS_PER_TILE == 32, "Pipelined CTA kernel currently assumes TOKENS_PER_TILE == 32."); + static_assert(STAGES >= 2 && STAGES <= 3, "STAGES must be 2 or 3."); + static_assert(sizeof(Tdata) == 2, "Pipelined CTA kernel supports only fp16/bf16."); + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + // Uniform return for empty tail CTAs (avoid deadlock with __syncthreads). + if (m_start >= q_len) { + return; + } + const bool is_active = (q_token_local < q_len); + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + const Tdata *q_ptr = nullptr; + Tdata *out_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? 
static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]); + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + // For this CTA, scan KV up to the max causal limit among active warps. + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + if (max_allowed_k_len <= 0) { + // Nothing to attend to (should be rare). Produce zeros. + if (is_active) { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + out_ptr[dim] = Tdata{}; + } + } + return; + } + + // cp.async uses 16B chunks; for fp16/bf16 that's 8 elements. + constexpr int CHUNK_ELEMS = 8; + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + // Multi-stage pipeline buffers. + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + // Per-warp scratch for tile-wise softmax (scores over TOKENS_PER_TILE). + // We keep scores in shared so each lane can load its token score (lane -> token index), + // then weights are broadcast via warp shuffles to avoid extra shared-memory traffic. + __shared__ float sh_scores[BLOCK_M][TOKENS_PER_TILE]; + // Store Q in shared (per warp). This enables more tile-level parallelism in score + // computation without expensive cross-lane shuffles of Q registers. + __shared__ __align__(16) Tdata sh_q[BLOCK_M][HEAD_SIZE]; + + const int pbs = static_cast(page_block_size); + const int tid = threadIdx.x; + + // Populate per-warp Q shared tile once. +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + sh_q[warp_id][dim] = is_active ? 
q_ptr[dim] : Tdata{}; + } + __syncwarp(); + + int t_base = 0; + for (int logical_block = 0; t_base < max_allowed_k_len; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + static_cast(physical_block) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(physical_block) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, max_allowed_k_len - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + static_cast(token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_in_block + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + __syncthreads(); + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + const int global_k_base = t_base + token_in_block; + // Tile-wise online softmax (more FA2-like than per-token update): + // 1) Compute scores for this tile (masked to each warp's causal limit). + // 2) Compute tile max + sumexp. + // 3) Accumulate weighted V for the tile. + // 4) Merge into running (m, l, acc) in a numerically stable way. + // + // NOTE: this does not yet implement MMA / full tile-level GEMM; it mainly reduces + // the serial (lane0) online-softmax update frequency from per-token to per-tile. + float alpha = 1.0f; + float beta = 0.0f; + float tile_sumexp = 0.0f; + float tile_m = -INFINITY; + + if (allowed_k_len > 0) { + // 1) scores + // Increase tile-level parallelism vs the previous per-token loop: + // split the warp into 4 groups of 8 lanes; each group computes one token score in parallel. 
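+                    // Lane/group arithmetic example: lane 21 -> group_id = 21 / 8 = 2, lane_g = 21 & 7 = 5,
+                    // group_mask = 0xFFu << 16 = 0x00FF0000 (lanes 16..23). Group g owns token j_base + g,
+                    // and the width-8 __shfl_down_sync reduction below sums that group's partial dot product.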
+ constexpr int LANES_PER_GROUP = 8; + constexpr int GROUPS_PER_WARP = 4; + constexpr int DIMS_PER_GROUP_LANE = HEAD_SIZE / LANES_PER_GROUP; + static_assert(HEAD_SIZE % LANES_PER_GROUP == 0, "HEAD_SIZE must be divisible by 8."); + + const int group_id = lane / LANES_PER_GROUP; // [0..3] + const int lane_g = lane & (LANES_PER_GROUP - 1); // [0..7] + const unsigned int group_mask = 0xFFu << (group_id * LANES_PER_GROUP); + + for (int j_base = 0; j_base < TOKENS_PER_TILE; j_base += GROUPS_PER_WARP) { + const int j = j_base + group_id; // token index in [0..31] + const int kpos = global_k_base + j; + + const bool token_in_tile = (j < tile_n); + const bool token_unmasked = token_in_tile && (kpos < allowed_k_len); + + float qk_part = 0.0f; + if (token_unmasked) { + const Tdata *k_ptr = &sh_k[buf][j][0]; + const int dim_base = lane_g * DIMS_PER_GROUP_LANE; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const half2 *q2 = reinterpret_cast(&sh_q[warp_id][dim_base]); + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int t = 0; t < DIMS_PER_GROUP_LANE / 2; ++t) { + const float2 qf = __half22float2(q2[t]); + const float2 kf = __half22float2(k2[t]); + qk_part += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const __nv_bfloat162 *q2 = reinterpret_cast(&sh_q[warp_id][dim_base]); + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int t = 0; t < DIMS_PER_GROUP_LANE / 2; ++t) { + const float2 qf = __bfloat1622float2(q2[t]); + const float2 kf = __bfloat1622float2(k2[t]); + qk_part += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int t = 0; t < DIMS_PER_GROUP_LANE; ++t) { + qk_part += static_cast(sh_q[warp_id][dim_base + t]) * static_cast(k_ptr[dim_base + t]); + } + } + } + + // Reduce within 8-lane group. + for (int offset = LANES_PER_GROUP / 2; offset > 0; offset >>= 1) { + qk_part += __shfl_down_sync(group_mask, qk_part, offset, LANES_PER_GROUP); + } + + if (lane_g == 0) { + float score = -INFINITY; + if (token_unmasked) { + score = qk_part * scale_log2; + if (alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(kpos - causal_limit)) * kLog2e; + } + } + sh_scores[warp_id][j] = score; + } + } + __syncwarp(); + + // 2) tile max + sumexp (lane t corresponds to token t within the tile) + const float score_lane = (lane < tile_n) ? sh_scores[warp_id][lane] : -INFINITY; + float tile_m_tmp = op::paged_attention::cuda::warpReduceMax(score_lane); + tile_m_tmp = __shfl_sync(0xffffffff, tile_m_tmp, 0); + tile_m = tile_m_tmp; + + float w_lane = 0.0f; + if (lane < tile_n && tile_m != -INFINITY) { + w_lane = exp2f(score_lane - tile_m); + } + float sumexp_tmp = op::paged_attention::cuda::warpReduceSum(w_lane); + sumexp_tmp = __shfl_sync(0xffffffff, sumexp_tmp, 0); + tile_sumexp = sumexp_tmp; + + // 3) weighted V for this tile (per lane owns HEAD_SIZE/32 dims) + float acc_tile[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc_tile[i] = 0.0f; + } + + if (tile_sumexp > 0.0f) { + for (int j = 0; j < tile_n; ++j) { + // Broadcast weight for token j from lane j. 
+ const float wj = __shfl_sync(0xffffffff, w_lane, j); + const Tdata *v_ptr = &sh_v[buf][j][0]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __half22float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __bfloat1622float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + acc_tile[i] += wj * static_cast(v_ptr[dim]); + } + } + } + } + + // 4) merge tile into running (m, l, acc) + if (lane == 0) { + if (tile_sumexp > 0.0f && tile_m != -INFINITY) { + const float m_new = fmaxf(m, tile_m); + alpha = exp2f(m - m_new); + beta = exp2f(tile_m - m_new); + l = l * alpha + tile_sumexp * beta; + m = m_new; + } else { + alpha = 1.0f; + beta = 0.0f; + } + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc[i] = acc[i] * alpha + beta * acc_tile[i]; + } + } + + // IMPORTANT: warps in this CTA can have different allowed_k_len (due to causal mask + history), + // so they may finish the token loop at different times. We must not start prefetching into + // the circular shared-memory buffer until all warps finish consuming the current tile. + __syncthreads(); + + // Prefetch the tile that will reuse this buffer (STAGES steps ahead). 
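+ // [Editor's note] A sketch of the pipeline bookkeeping used here, assuming cpAsyncCommit /
+ // cpAsyncWaitGroupRt wrap PTX cp.async.commit_group / cp.async.wait_group: wait_group(N)
+ // returns once at most N of the most recently committed groups are still in flight, so
+ // clamping desired_pending to min(pending_groups - 1, STAGES - 1) before each wait guarantees
+ // that the oldest outstanding group (the tile consumed next) has landed in shared memory,
+ // while up to STAGES - 1 younger prefetches keep overlapping with compute.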
+ const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + static_cast(token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_prefetch + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + __syncthreads(); + } + } + + op::paged_attention::cuda::cpAsyncWaitAll(); + __syncthreads(); + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float out_val = acc[i] * inv_l; + if (!is_active) { + continue; + } + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(out_val); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(out_val); + } else { + out_ptr[dim] = static_cast(out_val); + } + } +} + +// Split-KV prefill (FA2-style): each split scans a shard of KV and writes partial (m, l, acc) +// to workspace. A separate combine kernel merges splits into the final output. +// +// Notes: +// - Implemented for the pipelined CTA kernel family (warpcta8pipe). We split by logical paged blocks. +// - Each warp still applies its own causal limit (allowed_k_len) so correctness is preserved. 
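+// [Editor's note] Combine math, for reference (this matches the combine kernel further below):
+// with per-split partials (m_s, l_s, acc_s) where l_s = sum_j 2^(score_j - m_s) and
+// acc_s[d] = sum_j 2^(score_j - m_s) * v_j[d], the merged output is
+//   m      = max_s m_s
+//   l      = sum_s l_s * 2^(m_s - m)
+//   out[d] = ( sum_s acc_s[d] * 2^(m_s - m) ) / l
+// Splits that saw no unmasked keys write m_s = -inf, l_s = 0 and drop out of both sums.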
+template +__device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + float *partial_acc, // [num_splits, total_q_tokens, num_heads, head_size] + float *partial_m, // [num_splits, total_q_tokens, num_heads] + float *partial_l, // [num_splits, total_q_tokens, num_heads] + int split_idx, + int num_splits, + int m_block, + size_t total_q_tokens, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + + (void)max_num_blocks_per_seq; + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <= 16."); + static_assert(TOKENS_PER_TILE == 32, "Split-KV prefill assumes TOKENS_PER_TILE == 32."); + static_assert(STAGES >= 2 && STAGES <= 3, "STAGES must be 2 or 3."); + static_assert(sizeof(Tdata) == 2, "Split-KV prefill supports only fp16/bf16."); + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + if (m_start >= q_len) { + return; // uniform + } + const bool is_active = (q_token_local < q_len); + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const size_t n = total_q_tokens * static_cast(num_heads); + size_t base = 0; + if (is_active) { + base = static_cast(q_token) * static_cast(num_heads) + static_cast(head_idx); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + const Tdata *q_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? 
static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + + float m = -INFINITY; + float l = 0.0f; + + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + if (max_allowed_k_len <= 0) { + if (is_active) { + const size_t idx = static_cast(split_idx) * n + base; + if (lane == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = 0.0f; + } + } + return; + } + + const int pbs = static_cast(page_block_size); + const int num_blocks_total = (max_allowed_k_len + pbs - 1) / pbs; + const int blocks_per_split = (num_blocks_total + num_splits - 1) / num_splits; + const int start_block = split_idx * blocks_per_split; + const int end_block = min(num_blocks_total, start_block + blocks_per_split); + if (start_block >= end_block) { + if (is_active) { + const size_t idx = static_cast(split_idx) * n + base; + if (lane == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = 0.0f; + } + } + return; + } + + const int max_allowed_k_len_split = min(max_allowed_k_len, end_block * pbs); + + constexpr int CHUNK_ELEMS = 8; + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ float sh_scores[BLOCK_M][TOKENS_PER_TILE]; + + const int tid = threadIdx.x; + + int t_base = start_block * pbs; + for (int logical_block = start_block; t_base < max_allowed_k_len_split; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + static_cast(physical_block) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(physical_block) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, max_allowed_k_len_split - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + static_cast(token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_in_block + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + 
desired_pending = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + __syncthreads(); + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + const int global_k_base = t_base + token_in_block; + + float alpha = 1.0f; + float beta = 0.0f; + float tile_sumexp = 0.0f; + float tile_m = -INFINITY; + float w_lane = 0.0f; + + if (allowed_k_len > 0) { + // 1) scores + for (int j = 0; j < tile_n; ++j) { + const int kpos = global_k_base + j; + const bool token_unmasked = (kpos < allowed_k_len); + float qk = 0.0f; + if (token_unmasked) { + const Tdata *k_ptr = &sh_k[buf][j][0]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int ii = 0; ii < DIMS_PER_THREAD / 2; ++ii) { + const float2 qf = __half22float2(q2[ii]); + const float2 kf = __half22float2(k2[ii]); + qk = fmaf(qf.x, kf.x, qk); + qk = fmaf(qf.y, kf.y, qk); + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int ii = 0; ii < DIMS_PER_THREAD / 2; ++ii) { + const float2 qf = __bfloat1622float2(q2[ii]); + const float2 kf = __bfloat1622float2(k2[ii]); + qk = fmaf(qf.x, kf.x, qk); + qk = fmaf(qf.y, kf.y, qk); + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk = fmaf(q_reg[i], static_cast(k_ptr[dim]), qk); + } + } + } + qk = op::paged_attention::cuda::warpReduceSum(qk); + if (lane == 0) { + float score = token_unmasked ? (qk * scale_log2) : -INFINITY; + if (token_unmasked && alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(kpos - causal_limit)) * kLog2e; + } + sh_scores[warp_id][j] = score; + } + } + __syncwarp(); + + // 2) tile max / sumexp + float max_tmp = -INFINITY; + if (lane < tile_n) { + max_tmp = sh_scores[warp_id][lane]; + } + max_tmp = op::paged_attention::cuda::warpReduceMax(max_tmp); + max_tmp = __shfl_sync(0xffffffff, max_tmp, 0); + tile_m = max_tmp; + + if (lane < tile_n) { + const float s = sh_scores[warp_id][lane]; + w_lane = (s == -INFINITY) ? 
0.0f : exp2f(s - tile_m); + } else { + w_lane = 0.0f; + } + float sumexp_tmp = op::paged_attention::cuda::warpReduceSum(w_lane); + sumexp_tmp = __shfl_sync(0xffffffff, sumexp_tmp, 0); + tile_sumexp = sumexp_tmp; + + // 3) weighted V for this tile + float acc_tile[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc_tile[i] = 0.0f; + } + if (tile_sumexp > 0.0f) { + for (int j = 0; j < tile_n; ++j) { + const float wj = __shfl_sync(0xffffffff, w_lane, j); + const Tdata *v_ptr = &sh_v[buf][j][0]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __half22float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __bfloat1622float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + acc_tile[i] += wj * static_cast(v_ptr[dim]); + } + } + } + } + + // 4) merge tile into running (m, l, acc) + if (lane == 0) { + if (tile_sumexp > 0.0f && tile_m != -INFINITY) { + const float m_new = fmaxf(m, tile_m); + alpha = exp2f(m - m_new); + beta = exp2f(tile_m - m_new); + l = l * alpha + tile_sumexp * beta; + m = m_new; + } else { + alpha = 1.0f; + beta = 0.0f; + } + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc[i] = acc[i] * alpha + beta * acc_tile[i]; + } + } + + __syncthreads(); + + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + static_cast(token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_prefetch + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + __syncthreads(); + } + } + + op::paged_attention::cuda::cpAsyncWaitAll(); + __syncthreads(); + } + + if (is_active) { + const size_t idx = static_cast(split_idx) * n + base; + if (lane == 0) { + partial_m[idx] = m; + partial_l[idx] = l; + } +#pragma unroll + for (int i 
= 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = acc[i]; + } + } +} + +template +__device__ void PagedAttentionPrefillSplitKvCombineWarpKernel( + Tdata *out_, + const float *partial_acc, // [num_splits, total_q_tokens, num_heads, head_size] + const float *partial_m, // [num_splits, total_q_tokens, num_heads] + const float *partial_l, // [num_splits, total_q_tokens, num_heads] + int num_splits, + size_t total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + const int head_idx = static_cast(blockIdx.x); + const int token_idx = static_cast(blockIdx.y); + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int num_heads = gridDim.x; + const size_t n = total_q_tokens * static_cast(num_heads); + const size_t base = static_cast(token_idx) * static_cast(num_heads) + static_cast(head_idx); + + float m = -INFINITY; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + m = fmaxf(m, partial_m[static_cast(s) * n + base]); + } + } + m = __shfl_sync(0xffffffff, m, 0); + + float l = 0.0f; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[static_cast(s) * n + base]; + const float ls = partial_l[static_cast(s) * n + base]; + if (ls > 0.0f) { + l += ls * exp2f(ms - m); + } + } + } + l = __shfl_sync(0xffffffff, l, 0); + const float inv_l = 1.0f / (l + 1e-6f); + + Tdata *out_ptr = out_ + static_cast(token_idx) * o_stride + static_cast(head_idx) * o_head_stride; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + float acc = 0.0f; + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[static_cast(s) * n + base]; + const float w = exp2f(ms - m); + acc += partial_acc[(static_cast(s) * n + base) * HEAD_SIZE + dim] * w; + } + const float o = acc * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +// Variant for large K tile where (K+V) shared memory would exceed the per-block limit on some GPUs. +// We keep K in shared memory for reuse across warps, but load V directly from global memory. 
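+// [Editor's note] A rough budget sketch behind this trade-off, for the BLOCK_N=128 / HEAD_SIZE=128
+// instantiation used by the N128 wrapper with fp16/bf16 data: the K tile alone takes
+//   128 tokens * 128 dims * 2 B = 32 KiB,
+// plus about 1 KiB for s_phys/s_off. Mirroring V in shared memory would roughly double that to
+// ~64 KiB, past the 48 KiB static shared-memory cap assumed elsewhere in this file, so V is
+// streamed directly from global memory per token instead.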
+template +__device__ void PagedAttentionPrefillWarpCtaKernelKOnly( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <=16."); + static_assert(BLOCK_N > 0 && BLOCK_N <= 128, "BLOCK_N must be <=128."); + + constexpr int kWarpSize = 32; + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + // IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads() + // later. Tail tiles are handled by masking inactive warps. + if (m_start >= q_len) { + return; // uniform across the CTA + } + const bool is_active = (q_token_local < q_len); + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + const Tdata *q_ptr = nullptr; + Tdata *out_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? 
static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]); + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + + __shared__ int32_t s_phys[BLOCK_N]; + __shared__ int32_t s_off[BLOCK_N]; + __shared__ __align__(16) Tdata s_k[BLOCK_N * HEAD_SIZE]; + + const int pbs = static_cast(page_block_size); + + for (int k_base = 0; k_base < max_allowed_k_len; k_base += BLOCK_N) { + const int tile_n = min(BLOCK_N, max_allowed_k_len - k_base); + + for (int t = threadIdx.x; t < tile_n; t += blockDim.x) { + const int kpos = k_base + t; + const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs); + const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs); + const int32_t phys = static_cast(block_table[page]); + s_phys[t] = phys; + s_off[t] = off; + } + __syncthreads(); + + const int tile_elems = tile_n * HEAD_SIZE; + for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) { + const int t = idx / HEAD_SIZE; + const int dim = idx - t * HEAD_SIZE; + const int32_t phys = s_phys[t]; + const int32_t off = s_off[t]; + const Tdata *k_base_ptr = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + s_k[t * HEAD_SIZE + dim] = k_base_ptr[dim]; + } + __syncthreads(); + + for (int t = 0; t < tile_n; ++t) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + break; + } + const Tdata *k_ptr = s_k + t * HEAD_SIZE; + const int32_t phys = s_phys[t]; + const int32_t off = s_off[t]; + const Tdata *v_ptr = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(kpos - (allowed_k_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; 
++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + + __syncthreads(); + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float out_val = acc[i] * inv_l; + if (!is_active) { + continue; + } + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(out_val); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(out_val); + } else { + out_ptr[dim] = static_cast(out_val); + } + } +} + +// TensorCore (WMMA) score kernel (v0.4 experimental): +// - Target shape: head_dim=128, page_block_size=256, fp16. +// - Compute QK^T with WMMA into shared memory, then reuse the existing online-softmax + V accumulation +// pattern (SIMT) per query row. +// +// Notes: +// - This is a correctness-first kernel. It doesn't yet use MMA for PV (P * V) update. +// - We keep the same grid mapping as other prefill kernels: blockIdx = (head, seq, m_block). +template +__device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow( + int lane, + int k_base, + int allowed_k_len, + const float *scores_row, // [kBlockN] + const half *v_tile, // [kBlockN, kHeadDim] + float scale_log2, + float alibi_slope_log2, + float &m, + float &l, + float *acc) { // [kDimsPerThread] + + // Max over keys in this tile. + float local_max = -INFINITY; + for (int t = lane; t < kBlockN; t += kWarpSize) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + continue; + } + float score = scores_row[t] * scale_log2; + if (alibi_slope_log2 != 0.0f) { + score += alibi_slope_log2 * static_cast(kpos - (allowed_k_len - 1)); + } + local_max = fmaxf(local_max, score); + } + float tile_m = op::paged_attention::cuda::warpReduceMax(local_max); + tile_m = __shfl_sync(0xffffffff, tile_m, 0); + + // Sumexp + weighted V over keys in this tile, partitioned by lanes. 
+ float sumexp_lane = 0.0f; + float acc_tile[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f}; + const int dim_base = lane * kDimsPerThread; + if (tile_m != -INFINITY) { + for (int t = lane; t < kBlockN; t += kWarpSize) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + continue; + } + float score = scores_row[t] * scale_log2; + if (alibi_slope_log2 != 0.0f) { + score += alibi_slope_log2 * static_cast(kpos - (allowed_k_len - 1)); + } + const float w = exp2f(score - tile_m); + sumexp_lane += w; + + const half *v_ptr = v_tile + t * kHeadDim + dim_base; + const half2 *v2 = reinterpret_cast(v_ptr); +#pragma unroll + for (int j = 0; j < kDimsPerThread / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc_tile[j * 2 + 0] += w * vf.x; + acc_tile[j * 2 + 1] += w * vf.y; + } + } + } + + float tile_sumexp = op::paged_attention::cuda::warpReduceSum(sumexp_lane); + tile_sumexp = __shfl_sync(0xffffffff, tile_sumexp, 0); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + if (tile_sumexp > 0.0f && tile_m != -INFINITY) { + const float m_new = fmaxf(m, tile_m); + alpha = exp2f(m - m_new); + beta = exp2f(tile_m - m_new); + l = l * alpha + tile_sumexp * beta; + m = m_new; + } else { + alpha = 1.0f; + beta = 0.0f; + } + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + acc[i] = acc[i] * alpha + beta * acc_tile[i]; + } +} + +template +__device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow( + int lane, + bool active, + int q_token_local, + int64_t q_start, + int head_idx, + half *out_, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + float l, + const float *acc) { // [kDimsPerThread] + if (!active) { + return; + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + + const int64_t q_token = q_start + static_cast(q_token_local); + half *out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + const int dim = lane * kDimsPerThread + i; + out_ptr[dim] = __float2half_rn(acc[i] * inv_l); + } +} + +template +__device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( + half *out_, + const half *q_, + const half *k_cache_, + const half *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + (void)max_num_blocks_per_seq; + + constexpr int kWarpSize = 32; + constexpr int kWarps = 8; + constexpr int kHeadDim = 128; + // Extra padding in the K dimension to reduce shared-memory bank conflicts for ldmatrix / wmma loads. + // NOTE: FA2 uses a swizzled smem layout; padding is a smaller step that keeps our code simple. + constexpr int kHeadDimSmem = 136; // must be a multiple of 8 for wmma::load_matrix_sync + constexpr int kBlockM = 16; // 2 rows per warp + // Keep static shared memory <= 48KB for compatibility with build targets that cap SMEM at 0xC000. + // kBlockN=64 brings s_q+s_k+s_v+s_scores+s_phys/s_off down to ~41KB. 
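+ // [Editor's note] With the padded stride, the static buffers declared below come to roughly:
+ //   s_q              16 * 136 * 2 B  =  4352 B
+ //   s_k, s_v (each)  64 * 136 * 2 B  = 17408 B
+ //   s_scores         16 *  64 * 4 B  =  4096 B
+ //   s_phys + s_off                   =   512 B
+ // i.e. about 43.8 KB total, still under the 48 KB (0xC000) cap; the ~41KB figure above appears
+ // to have been computed with the unpadded kHeadDim.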
+ constexpr int kBlockN = 64; + constexpr int kDimsPerThread = kHeadDim / kWarpSize; + + static_assert(kHeadDim % kWarpSize == 0, "head_dim must be divisible by 32."); + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= kWarps) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * kBlockM; + // Uniform early return for empty tail tiles (avoid deadlock with __syncthreads()). + if (m_start >= q_len) { + return; + } + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + + // Clamp max k length for this CTA based on the last active query row in the tile. + const int max_q_in_tile = min(m_start + kBlockM, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + const float alibi_slope_log2 = alibi_slope * kLog2e; + + const int pbs = static_cast(page_block_size); + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + // Shared memory: + // - s_q: [kBlockM, kHeadDimSmem] (padded) + // - s_k/s_v: [kBlockN, kHeadDim] + // - s_scores: [kBlockM, kBlockN] raw dot products (no scale / alibi) + __shared__ __align__(16) half s_q[kBlockM * kHeadDimSmem]; + __shared__ int32_t s_phys[kBlockN]; + __shared__ int32_t s_off[kBlockN]; + __shared__ __align__(16) half s_k[kBlockN * kHeadDimSmem]; + __shared__ __align__(16) half s_v[kBlockN * kHeadDimSmem]; + __shared__ __align__(16) float s_scores[kBlockM * kBlockN]; + + // Load Q tile (pad inactive rows with 0). + for (int idx = threadIdx.x; idx < kBlockM * kHeadDim; idx += blockDim.x) { + const int r = idx / kHeadDim; + const int d = idx - r * kHeadDim; + const int q_token_local = m_start + r; + if (q_token_local < q_len) { + const int64_t q_token = q_start + static_cast(q_token_local); + const half *q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + s_q[r * kHeadDimSmem + d] = q_ptr[d]; + } else { + s_q[r * kHeadDimSmem + d] = __float2half_rn(0.0f); + } + } + __syncthreads(); + + // Two rows per warp: row0=warp_id, row1=warp_id+kWarps. + const int row0 = warp_id; + const int row1 = warp_id + kWarps; + const bool active0 = (row0 < kBlockM) && ((m_start + row0) < q_len); + const bool active1 = (row1 < kBlockM) && ((m_start + row1) < q_len); + const int allowed0 = active0 ? min(history_len + (m_start + row0) + 1, kv_len_total) : 0; + const int allowed1 = active1 ? min(history_len + (m_start + row1) + 1, kv_len_total) : 0; + + float m0 = -INFINITY, l0 = 0.0f; + float m1 = -INFINITY, l1 = 0.0f; + float acc0[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f}; + float acc1[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f}; + + // Iterate over K/V tiles. 
+ for (int k_base = 0; k_base < max_allowed_k_len; k_base += kBlockN) { + // Map logical k positions to physical blocks for this tile (pad the tail with -1). + for (int t = threadIdx.x; t < kBlockN; t += blockDim.x) { + const int kpos = k_base + t; + if (kpos < max_allowed_k_len) { + const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs); + const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs); + s_phys[t] = static_cast(block_table[page]); + s_off[t] = off; + } else { + s_phys[t] = -1; + s_off[t] = 0; + } + } + __syncthreads(); + + // Load K/V tile into shared memory (pad with 0 for inactive tokens). + for (int idx = threadIdx.x; idx < kBlockN * kHeadDim; idx += blockDim.x) { + const int t = idx / kHeadDim; + const int d = idx - t * kHeadDim; + const int32_t phys = s_phys[t]; + if (phys >= 0) { + const int32_t off = s_off[t]; + const half *k_ptr = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + const half *v_ptr = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + s_k[t * kHeadDimSmem + d] = k_ptr[d]; + s_v[t * kHeadDimSmem + d] = v_ptr[d]; + } else { + s_k[t * kHeadDimSmem + d] = __float2half_rn(0.0f); + s_v[t * kHeadDimSmem + d] = __float2half_rn(0.0f); + } + } + __syncthreads(); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + // WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows. + // For kBlockN=64, only the first 4 warps participate in WMMA score computation. + namespace wmma = nvcuda::wmma; + constexpr int kNSub = kBlockN / 16; + if (warp_id < kNSub) { + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + const int n_sub = warp_id; // [0, kNSub) + const half *q_tile = s_q; + const half *k_tile = s_k + (n_sub * 16) * kHeadDimSmem; + // K loop (head_dim=128). +#pragma unroll + for (int kk = 0; kk < (kHeadDim / 16); ++kk) { + wmma::load_matrix_sync(a_frag, q_tile + kk * 16, kHeadDimSmem); + wmma::load_matrix_sync(b_frag, k_tile + kk * 16, kHeadDimSmem); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + float *scores_tile = s_scores + n_sub * 16; + wmma::store_matrix_sync(scores_tile, c_frag, kBlockN, wmma::mem_row_major); + } +#else + // No WMMA support on this architecture: fall back to scalar dot in the existing kernels. + // (We keep scores as 0 so this kernel is effectively incorrect; host dispatch must avoid selecting it.) + if (threadIdx.x == 0) { + // Intentionally empty. + } +#endif + __syncthreads(); + + // Online softmax + V update per row handled by the same warp across tiles. + if (row0 < kBlockM) { + PagedAttentionPrefillMmaScoreUpdateRow( + lane, k_base, allowed0, s_scores + row0 * kBlockN, s_v, scale_log2, alibi_slope_log2, m0, l0, acc0); + } + if (row1 < kBlockM) { + PagedAttentionPrefillMmaScoreUpdateRow( + lane, k_base, allowed1, s_scores + row1 * kBlockN, s_v, scale_log2, alibi_slope_log2, m1, l1, acc1); + } + __syncthreads(); + } + + // Write outputs. 
+ if (row0 < kBlockM) { + PagedAttentionPrefillMmaScoreWriteRow( + lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0); + } + if (row1 < kBlockM) { + PagedAttentionPrefillMmaScoreWriteRow( + lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1); + } +} + +} // namespace op::paged_attention_prefill::cuda + +#endif diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h index 6f1809f06..a40f4ceaf 100644 --- a/src/infiniop/ops/paged_attention_prefill/info.h +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -3,6 +3,7 @@ #include "../../../utils.h" #include "../../tensor.h" +#include #include #include #include @@ -14,21 +15,30 @@ class PagedAttentionPrefillInfo { public: infiniDtype_t dtype; + infiniDtype_t index_dtype; float scale; size_t num_seqs; + size_t total_q_tokens; size_t num_heads; size_t num_kv_heads; size_t head_size; - size_t block_size; + size_t page_block_size; size_t max_num_blocks_per_seq; - size_t total_q_tokens; + size_t num_blocks; ptrdiff_t q_stride; ptrdiff_t q_head_stride; - ptrdiff_t kv_block_stride; - ptrdiff_t kv_head_stride; + ptrdiff_t k_batch_stride; + ptrdiff_t k_row_stride; + ptrdiff_t k_head_stride; + ptrdiff_t v_batch_stride; + ptrdiff_t v_row_stride; + ptrdiff_t v_head_stride; ptrdiff_t o_stride; + ptrdiff_t o_head_stride; + + ptrdiff_t block_table_batch_stride; static utils::Result create( infiniopTensorDescriptor_t out_desc, @@ -36,89 +46,161 @@ class PagedAttentionPrefillInfo { infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, - infiniopTensorDescriptor_t cum_seq_lens_q_desc, + infiniopTensorDescriptor_t total_kv_lens_desc, + infiniopTensorDescriptor_t cum_seqlens_q_desc, const std::optional &alibi_slopes_desc, float scale) { auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); - + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) { + // q/out: [total_q, heads, head_dim] + if (q_desc->ndim() != 3 || out_desc->ndim() != 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + // FA2 paged KV layout: [num_blocks, page_block_size, kv_heads, head_dim] + if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (block_tables_desc->ndim() != 2 || total_kv_lens_desc->ndim() != 1 || cum_seqlens_q_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + // Index dtypes: allow I32/I64/U32 (v0.4 roadmap allows internal conversion to I32). 
+ const auto block_tables_dt = block_tables_desc->dtype(); + if (!((block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32))) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } + // Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now + // (matches current paged_attention_prefill signature). We will convert to int32 internally later. + if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) { + if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (alibi_slopes_desc.value()->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); } - auto k_shape = k_cache_desc->shape(); - auto v_shape = v_cache_desc->shape(); - auto block_tables_shape = block_tables_desc->shape(); - auto seq_lens_shape = seq_lens_desc->shape(); - auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape(); + const auto q_shape = q_desc->shape(); + const auto k_shape = k_cache_desc->shape(); + + const size_t total_q_tokens = q_shape[0]; + const size_t num_heads = q_shape[1]; + const size_t head_size = q_shape[2]; + + const size_t num_blocks = k_shape[0]; + const size_t page_block_size = k_shape[2]; + const size_t num_kv_heads = k_shape[1]; - if (k_shape.size() != 4 || v_shape.size() != 4) { + if (head_size != 64 && head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_heads % num_kv_heads != 0) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (block_tables_shape.size() != 2) { + // v_cache must match the inferred K layout. 
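+ // [Editor's note] Concretely, the layout assumed by the shape reads above and the checks below is
+ //   k_cache/v_cache: [num_blocks, kv_heads (shape[1]), page_block_size (shape[2]), head_dim],
+ // and the strides extracted further down address an element as
+ //   k[b, h, t, d] = k_cache + b * k_batch_stride + h * k_head_stride + t * k_row_stride + d
+ // (stride(3) == 1 is enforced earlier). Note that the FA2-layout comment near the top of this
+ // function lists the middle two dims in the opposite order; the code follows the ordering here.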
+ const auto v_shape = v_cache_desc->shape(); + if (v_shape[0] != num_blocks || v_shape[3] != head_size) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) { + if (v_shape[1] != num_kv_heads || v_shape[2] != page_block_size) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) { - return INFINI_STATUS_BAD_PARAM; + if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[3] != k_shape[3]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; } - // Q shape: [total_tokens, heads, dim] - auto q_shape = q_desc->shape(); - if (q_shape.size() != 3) { + if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - size_t total_q_tokens = q_shape[0]; - size_t num_heads = q_shape[1]; - size_t head_size = q_shape[2]; - if (head_size > 1024) { + const size_t num_seqs = total_kv_lens_desc->shape()[0]; + if (cum_seqlens_q_desc->shape()[0] != num_seqs + 1) { return INFINI_STATUS_BAD_PARAM; } - size_t num_seqs = seq_lens_shape[0]; - size_t num_kv_heads = k_shape[1]; - size_t block_size = k_shape[2]; - size_t max_num_blocks_per_seq = block_tables_shape[1]; - - ptrdiff_t q_stride = q_desc->stride(0); - ptrdiff_t q_head_stride = q_desc->stride(1); - ptrdiff_t kv_block_stride = k_cache_desc->stride(0); - ptrdiff_t kv_head_stride = k_cache_desc->stride(1); - ptrdiff_t o_stride = out_desc->stride(0); + const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // Strides (in elements) + const ptrdiff_t q_stride = q_desc->stride(0); + const ptrdiff_t q_head_stride = q_desc->stride(1); + const ptrdiff_t o_stride = out_desc->stride(0); + const ptrdiff_t o_head_stride = out_desc->stride(1); + + const ptrdiff_t k_batch_stride = k_cache_desc->stride(0); + const ptrdiff_t k_row_stride = k_cache_desc->stride(2); + const ptrdiff_t k_head_stride = k_cache_desc->stride(1); + + const ptrdiff_t v_batch_stride = v_cache_desc->stride(0); + const ptrdiff_t v_row_stride = v_cache_desc->stride(2); + const ptrdiff_t v_head_stride = v_cache_desc->stride(1); + + const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0); + + if (const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_INFO")) { + static bool printed = false; + if (!printed && std::strcmp(dbg, "1") == 0) { + const auto bt_shape = block_tables_desc->shape(); + std::fprintf(stderr, + "[infiniop][flash_attention_prefill][info] k_shape=[%zu,%zu,%zu,%zu] k_strides=[%td,%td,%td,%td] (row_stride=%td head_stride=%td)\n", + static_cast(k_shape[0]), static_cast(k_shape[1]), + static_cast(k_shape[2]), static_cast(k_shape[3]), + k_cache_desc->stride(0), k_cache_desc->stride(1), k_cache_desc->stride(2), k_cache_desc->stride(3), + k_row_stride, k_head_stride); + std::fprintf(stderr, + "[infiniop][flash_attention_prefill][info] block_tables shape=[%zu,%zu] strides=[%td,%td]\n", + static_cast(bt_shape[0]), static_cast(bt_shape[1]), + block_tables_desc->stride(0), block_tables_desc->stride(1)); + printed = true; + } + } return utils::Result(PagedAttentionPrefillInfo{ dtype, + block_tables_dt, scale, num_seqs, + total_q_tokens, num_heads, num_kv_heads, head_size, - block_size, + page_block_size, max_num_blocks_per_seq, - total_q_tokens, + num_blocks, q_stride, q_head_stride, - kv_block_stride, - kv_head_stride, - o_stride}); + k_batch_stride, + k_row_stride, + k_head_stride, + v_batch_stride, + v_row_stride, + v_head_stride, + o_stride, + o_head_stride, + 
block_table_batch_stride, + }); } }; - } // namespace op::paged_attention_prefill #endif diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu index 90c4c94fc..e95268a84 100644 --- a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -1,56 +1,1237 @@ -#include -#include -#include -#include +#include + +#include +#include +#include +#include #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh" -#include "../cuda/kernel.cuh" + +// #include "paged_attention_prefill_fa2.cuh" #include "paged_attention_prefill_nvidia.cuh" -template -infiniStatus_t launchPagedAttentionPrefill( - Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, - const int64_t *block_tables, - const int64_t *seq_lens, - const int64_t *cum_seq_lens_q, - const float *alibi_slopes, - const size_t num_heads, - const size_t num_seqs, - const size_t num_kv_heads, - const float scale, - const size_t max_num_blocks_per_seq, - const size_t block_size, - const size_t total_q_tokens, - const size_t head_size, - const ptrdiff_t kv_block_stride, - const ptrdiff_t kv_head_stride, - const ptrdiff_t q_stride, - const ptrdiff_t q_head_stride, +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention_prefill::nvidia { + +namespace { +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) { + // Heuristic auto-dispatch (v0.4): + // - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256. + // - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80). + // + // Users can always override via INFINIOP_FLASH_PREFILL_KERNEL. + if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) { + if (info.head_size == 128) { + return "warpcta8pipe"; + } + // For head_size=64 we keep the previous default until we have broader perf coverage. + } + return "warpcta8"; +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel). 
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride, + q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel). + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride, + q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 4 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 4 warps per CTA, one warp per query token. 
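+ // [Editor's note] Launch geometry expected by this wrapper, inferred from the kernel's use of
+ // blockIdx and warp_id (the host-side launch is outside this hunk):
+ //   grid  = (num_heads, num_seqs, ceilDiv(max_q_len_in_batch, /*BLOCK_M=*/4))
+ //   block = 4 warps * 32 = 128 threads
+ // blockIdx.x selects the head, blockIdx.y the sequence, blockIdx.z the BLOCK_M-sized query tile;
+ // max_q_len_in_batch is a hypothetical name here for the largest per-sequence query length.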
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages. + // Note: we keep K in shared memory but load V from global to stay within the per-block shared limit. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token. 
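+    // "warpcta8" is the conservative default returned by default_prefill_kernel() whenever the
+    // pipelined fast path (fp16/bf16, head_dim=128, block_size=256) does not apply.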
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, with cp.async pipelining. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma( + half *out, + const half *q, + const half *k_cache, + const half *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, with cp.async pipelining. 
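+    // Reminder from the dispatch heuristic above: cp.async is a no-op below SM80, so the extra
+    // pipelining here is only expected to pay off on Ampere-class (SM80+) and newer GPUs.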
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + size_t total_q_tokens, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + // Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch: + // blockIdx.z in [0, num_splits * num_m_blocks). + const int num_m_blocks = static_cast((total_q_tokens + 8 - 1) / 8); + const int bz = static_cast(blockIdx.z); + const int split_idx = bz / num_m_blocks; + const int m_block = bz - split_idx * num_m_blocks; + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + size_t total_q_tokens, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + const int num_m_blocks = static_cast((total_q_tokens + 8 - 1) / 8); + const int bz = static_cast(blockIdx.z); + const int split_idx = bz / num_m_blocks; + const int m_block = bz - split_idx * num_m_blocks; + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + size_t 
total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + size_t total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 16 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 16 warps per CTA, one warp per query token. 
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +infiniStatus_t launch_prefill_ref( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, cudaStream_t stream) { - if (total_q_tokens == 0 || num_heads == 0) { + const dim3 grid(static_cast(total_q_tokens), static_cast(num_heads), 1); + const dim3 block(static_cast(head_size), 1, 1); + + if (head_size == 64) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, num_seqs); + return INFINI_STATUS_SUCCESS; + } + + if (head_size == 128) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, num_seqs); + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +template +infiniStatus_t launch_prefill_warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + const dim3 block(32, 1, 1); + // Global-token launch: + // - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch + // - matches PagedAttention varlen (cu_seqlens) mental model better + const dim3 grid(static_cast(num_heads), + static_cast(total_q_tokens), + 1); + + switch (head_size) { + case 64: + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel + <<>>( + out, q, k_cache, v_cache, 
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq, + page_block_size, block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq, + page_block_size, block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: return INFINI_STATUS_BAD_TENSOR_SHAPE; } +} + +template +infiniStatus_t launch_prefill( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 4; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t 
stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); - dim3 grid(total_q_tokens, num_heads); - dim3 block(head_size); + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta8 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} - op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel +template +infiniStatus_t launch_prefill_warpcta8pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8Pipe + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta8Pipe + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8mma( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, 
+ float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + // Current WMMA kernel only supports fp16 + head_dim=128. + if constexpr (!std::is_same_v) { + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + + if (head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts. + // Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1. + const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE"); + const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0); + const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size; + if (!force_mma && seqlen_k_est > 4096) { + static bool warned = false; + if (!warned) { + std::fprintf(stderr, + "[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). " + "Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n", + seqlen_k_est); + warned = true; + } + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + + // WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel. 
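+    // Worked example of the guardrail above (illustrative numbers): with page_block_size = 256 and
+    // max_num_blocks_per_seq = 32, seqlen_k_est = 32 * 256 = 8192 > 4096, so the pipelined SIMT
+    // kernel is used unless INFINIOP_FLASH_PREFILL_MMA_FORCE=1 is set. The SM check below covers
+    // pre-Volta (compute capability < 7.0) devices, where WMMA is unavailable.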
+ int device = 0; + cudaDeviceProp prop{}; + if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) { + if (prop.major < 7) { + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + } + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(16)))); + + PagedAttentionPrefillHd128WarpCta8Mma <<>>( - out, q, k_cache, v_cache, - block_tables, seq_lens, cum_seq_lens_q, alibi_slopes, - num_heads, num_kv_heads, scale, - max_num_blocks_per_seq, block_size, - kv_block_stride, kv_head_stride, + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, - head_size, - num_seqs); + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; +} +template +infiniStatus_t launch_prefill_warpcta8pipe_splitkv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kMaxSplits = 8; + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast(kWarps)); + // Single kernel launch with split_idx encoded in grid.z: + // blockIdx.z in [0, num_splits * num_m_blocks). 
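+    // The split-kv kernels above recover the pair from blockIdx.z by integer division:
+    //   num_m_blocks = ceilDiv(total_q_tokens, kWarps);
+    //   split_idx    = blockIdx.z / num_m_blocks;
+    //   m_block      = blockIdx.z - split_idx * num_m_blocks;   // i.e. blockIdx.z % num_m_blocks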
+ const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(num_m_blocks * static_cast(num_splits))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8PipeSplitKv + <<>>( + partial_acc, partial_m, partial_l, num_splits, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); + break; + case 128: + PagedAttentionPrefillHd128WarpCta8PipeSplitKv + <<>>( + partial_acc, partial_m, partial_l, num_splits, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); + break; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Combine: one warp per (token, head). + const dim3 block2(32); + const dim3 grid2(static_cast(num_heads), static_cast(total_q_tokens), 1); + switch (head_size) { + case 64: + PagedAttentionPrefillHd64SplitKvCombine + <<>>( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128SplitKvCombine + <<>>( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8n128( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + // Only meaningful for head_dim=128. 
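+    // As noted on the kernel wrapper above, the n128 variant keeps only K tiles in shared memory and
+    // streams V from global memory to fit the per-block shared-memory limit; only a head_dim=128
+    // specialization exists, hence the shape check below.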
+ if (head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + PagedAttentionPrefillHd128WarpCta8N128 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); return INFINI_STATUS_SUCCESS; } -namespace op::paged_attention_prefill::nvidia { +template +infiniStatus_t launch_prefill_warpcta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 16; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta16 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta16 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} +} // namespace struct Descriptor::Opaque { std::shared_ptr internal; @@ -68,22 +1249,87 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, - infiniopTensorDescriptor_t cum_seq_lens_q_desc, + infiniopTensorDescriptor_t total_kv_lens_desc, + infiniopTensorDescriptor_t cum_seqlens_q_desc, const std::optional &alibi_slopes_desc, float scale) { auto info = PagedAttentionPrefillInfo::create( out_desc, q_desc, k_cache_desc, v_cache_desc, - block_tables_desc, seq_lens_desc, - cum_seq_lens_q_desc, + block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc, alibi_slopes_desc, scale); - CHECK_RESULT(info); + // Optional split-kv prefill requires workspace for partial (m, l, acc). + // IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve + // a huge workspace unless the user explicitly enables split-kv. 
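+    // Split-kv workspace sizing used below, for reference:
+    //   floats = num_splits * (total_q_tokens * num_heads) * (head_size + 2)
+    // laid out as a num_splits * n * head_size accumulator followed by the per-split m and l buffers
+    // (num_splits * n floats each; see the pointer arithmetic in Descriptor::calculate).
+    // Illustrative magnitude: num_splits=4, total_q_tokens=4096, num_heads=32, head_size=128
+    // -> ~260 MiB, which is why the feature is strictly opt-in.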
+ bool use_splitkv = false; + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { + use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 1; + if (use_splitkv) { + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) { + const int v = std::atoi(env); + if (v > 0) { + num_splits = v; + } + } else { + num_splits = 4; + } + constexpr int kMaxSplits = 8; + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + } + const size_t n = info->total_q_tokens * info->num_heads; + const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0; + + // FA2-style kernel needs a workspace scratch for: + // - converting block_tables + total_kv_lens to int32 + // - storing softmax LSE (only required to satisfy the upstream kernel contract) + // bool want_fa2 = false; + // if (const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) { + // want_fa2 = (std::strcmp(k_env, "fa2") == 0); + // } + // bool fa2_materialize_kv = false; + // if (const char *env = std::getenv("INFINIOP_FA2_MATERIALIZE_PAGED_KV")) { + // fa2_materialize_kv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + // } + // size_t fa2_workspace_bytes = 0; + // // FA2 prefill supports both fp16 and bf16 inputs (head_dim=128, block_size=256). + // // Workspace sizing is identical since both are 16-bit element types. + // if (want_fa2 && (info->dtype == INFINI_DTYPE_F16 || info->dtype == INFINI_DTYPE_BF16) && + // info->head_size == 128 && info->page_block_size == 256) { + // const size_t bt_bytes = info->num_seqs * info->max_num_blocks_per_seq * sizeof(int); + // const size_t len_bytes = info->num_seqs * sizeof(int); + // const size_t cuq_bytes = (info->num_seqs + 1) * sizeof(int); + // const size_t cuk_bytes = (info->num_seqs + 1) * sizeof(int); + // const size_t lse_bytes = info->num_heads * info->total_q_tokens * sizeof(float); + // // Add a small alignment slack since we sub-allocate with alignment. + // fa2_workspace_bytes = bt_bytes + len_bytes + cuq_bytes + cuk_bytes + lse_bytes + 64; + + // // Optional: materialize paged KV into the FA2-friendly physical layout + // // [num_blocks, page_block_size, kv_heads, head_dim] (token-major) to avoid + // // extremely strided loads when the framework stores KV as + // // [num_blocks, kv_heads, page_block_size, head_dim] (head-major). + // if (fa2_materialize_kv) { + // // Materialize per-seq contiguous KV in *sequence order*: + // // [num_seqs, max_num_blocks_per_seq * page_block_size, kv_heads, head_dim]. 
+ // const size_t kv_elems = + // info->num_seqs * info->max_num_blocks_per_seq * info->page_block_size * info->num_kv_heads * info->head_size; + // const size_t kv_bytes = kv_elems * sizeof(uint16_t); // 16-bit (fp16/bf16) + // // K + V + alignment slack + // fa2_workspace_bytes += 2 * kv_bytes + 64; + // } + // } + + const size_t workspace_bytes = splitkv_workspace_bytes; + // const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes; + *desc_ptr = new Descriptor( new Opaque{reinterpret_cast(handle)->internal()}, - info.take(), 0, handle->device, handle->device_id); + info.take(), workspace_bytes, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -92,35 +1338,379 @@ infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache, const void *block_tables, - const void *seq_lens, - const void *cum_seq_lens_q, + const void *total_kv_lens, + const void *cum_seqlens_q, const void *alibi_slopes, void *stream_) const { + auto stream = static_cast(stream_); + + const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); + const auto *total_kv_lens_i64 = static_cast(total_kv_lens); + const auto *cu_seqlens_q_i64 = static_cast(cum_seqlens_q); + + bool use_splitkv = false; + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { + use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 1; + if (use_splitkv) { + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) { + const int v = std::atoi(env); + if (v > 0) { + num_splits = v; + } + } else { + // Conservative default; users can override. + num_splits = 4; + } + constexpr int kMaxSplits = 8; + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + const size_t n = _info.total_q_tokens * _info.num_heads; + const size_t required = static_cast(num_splits) * n * (_info.head_size + 2) * sizeof(float); + if (workspace_size < required) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + } + + if (use_splitkv) { + const size_t n = _info.total_q_tokens * _info.num_heads; + float *partial_acc = static_cast(workspace); + float *partial_m = partial_acc + static_cast(num_splits) * n * _info.head_size; + float *partial_l = partial_m + static_cast(num_splits) * n; + + // Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64. 
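+        // Illustrative environment setup for this path (variable names are the ones parsed above):
+        //   INFINIOP_FLASH_PREFILL_SPLITKV=1        # or "true"
+        //   INFINIOP_FLASH_PREFILL_NUM_SPLITS=4     # optional; clamped to at most 8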
+#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \ + return launch_prefill_warpcta8pipe_splitkv( \ + partial_acc, partial_m, partial_l, num_splits, \ + static_cast(out), \ + static_cast(q), \ + static_cast(k_cache), \ + static_cast(v_cache), \ + static_cast(BT_PTR), \ + total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream) + + if (_info.dtype == INFINI_DTYPE_F16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_SPLITKV(int64_t, half, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_SPLITKV(int32_t, half, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_SPLITKV(uint32_t, half, block_tables); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (_info.dtype == INFINI_DTYPE_BF16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + +#undef DISPATCH_SPLITKV + } + +// Default to the fastest validated kernel for supported shapes. +// "ref" is still available for debugging/correctness bisecting. +#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \ + do { \ + const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \ + const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \ + if (k_env != nullptr) { \ + const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \ + if (!known) { \ + const char *fallback = default_prefill_kernel(_info); \ + std::fprintf(stderr, \ + "[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \ + k, fallback); \ + k = fallback; \ + } \ + } \ + const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \ + static bool printed_dispatch = false; \ + if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \ + std::fprintf(stderr, \ + "[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \ + k, \ + (k_env == nullptr ? 
"auto" : "env"), \ + static_cast(_info.head_size), \ + static_cast(_info.page_block_size), \ + static_cast(_info.dtype)); \ + printed_dispatch = true; \ + } \ + if (std::strcmp(k, "warp") == 0) { \ + return launch_prefill_warp( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta") == 0) { \ + return launch_prefill( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta8") == 0) { \ + return launch_prefill_warpcta8( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta8pipe") == 0) { \ + return launch_prefill_warpcta8pipe( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if constexpr (std::is_same_v) { \ + if (std::strcmp(k, "warpcta8mma") == 0) { \ + return launch_prefill_warpcta8mma( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, 
_info.o_head_stride, stream); \ + } \ + } \ + if (std::strcmp(k, "warpcta8n128") == 0) { \ + return launch_prefill_warpcta8n128( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta16") == 0) { \ + return launch_prefill_warpcta16( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "ref") == 0) { \ + return launch_prefill_ref( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + return INFINI_STATUS_BAD_PARAM; \ + } while (false) - cudaStream_t stream = (cudaStream_t)stream_; - -#define LAUNCH_KERNEL(Tdata, Tcompute) \ - launchPagedAttentionPrefill( \ - (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \ - (const float *)alibi_slopes, \ - _info.num_heads, _info.num_seqs, _info.num_kv_heads, \ - _info.scale, _info.max_num_blocks_per_seq, \ - _info.block_size, _info.total_q_tokens, \ - _info.head_size, \ - _info.kv_block_stride, _info.kv_head_stride, \ - _info.q_stride, _info.q_head_stride, \ - stream) - - if (_info.dtype == INFINI_DTYPE_F16) { - return LAUNCH_KERNEL(half, float); - } else if (_info.dtype == INFINI_DTYPE_BF16) { - return LAUNCH_KERNEL(__nv_bfloat16, float); - } else if (_info.dtype == INFINI_DTYPE_F32) { - return LAUNCH_KERNEL(float, float); +#define DISPATCH_INDEX(Tindex) \ + do { \ + if (_info.dtype == INFINI_DTYPE_F16) { \ + DISPATCH_KERNEL(Tindex, half, float); \ + } \ + if (_info.dtype == INFINI_DTYPE_BF16) { \ + DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \ + } \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } while (false) + + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_INDEX(int64_t); + } else if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_INDEX(int32_t); + } else if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_INDEX(uint32_t); } return 
INFINI_STATUS_BAD_TENSOR_DTYPE; } } // namespace op::paged_attention_prefill::nvidia + +// #include +// #include +// #include +// #include + +// #include "../../../devices/nvidia/nvidia_common.cuh" +// #include "../../../devices/nvidia/nvidia_kernel_common.cuh" +// #include "../cuda/kernel.cuh" +// #include "paged_attention_prefill_nvidia.cuh" + +// template +// infiniStatus_t launchPagedAttentionPrefill( +// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, +// const int64_t *block_tables, +// const int64_t *seq_lens, +// const int64_t *cum_seq_lens_q, +// const float *alibi_slopes, +// const size_t num_heads, +// const size_t num_seqs, +// const size_t num_kv_heads, +// const float scale, +// const size_t max_num_blocks_per_seq, +// const size_t block_size, +// const size_t total_q_tokens, +// const size_t head_size, +// const ptrdiff_t kv_block_stride, +// const ptrdiff_t kv_head_stride, +// const ptrdiff_t q_stride, +// const ptrdiff_t q_head_stride, +// cudaStream_t stream) { + +// if (total_q_tokens == 0 || num_heads == 0) { +// return INFINI_STATUS_BAD_TENSOR_SHAPE; +// } + +// dim3 grid(total_q_tokens, num_heads); +// dim3 block(head_size); + +// op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel +// <<>>( +// out, q, k_cache, v_cache, +// block_tables, seq_lens, cum_seq_lens_q, alibi_slopes, +// num_heads, num_kv_heads, scale, +// max_num_blocks_per_seq, block_size, +// kv_block_stride, kv_head_stride, +// q_stride, q_head_stride, +// head_size, +// num_seqs); + +// return INFINI_STATUS_SUCCESS; +// } + +// namespace op::paged_attention_prefill::nvidia { + +// struct Descriptor::Opaque { +// std::shared_ptr internal; +// }; + +// Descriptor::~Descriptor() { +// delete _opaque; +// } + +// infiniStatus_t Descriptor::create( +// infiniopHandle_t handle, +// Descriptor **desc_ptr, +// infiniopTensorDescriptor_t out_desc, +// infiniopTensorDescriptor_t q_desc, +// infiniopTensorDescriptor_t k_cache_desc, +// infiniopTensorDescriptor_t v_cache_desc, +// infiniopTensorDescriptor_t block_tables_desc, +// infiniopTensorDescriptor_t seq_lens_desc, +// infiniopTensorDescriptor_t cum_seq_lens_q_desc, +// const std::optional &alibi_slopes_desc, +// float scale) { + +// auto info = PagedAttentionPrefillInfo::create( +// out_desc, q_desc, k_cache_desc, v_cache_desc, +// block_tables_desc, seq_lens_desc, +// cum_seq_lens_q_desc, +// alibi_slopes_desc, scale); + +// CHECK_RESULT(info); + +// *desc_ptr = new Descriptor( +// new Opaque{reinterpret_cast(handle)->internal()}, +// info.take(), 0, handle->device, handle->device_id); + +// return INFINI_STATUS_SUCCESS; +// } + +// infiniStatus_t Descriptor::calculate( +// void *workspace, size_t workspace_size, +// void *out, const void *q, const void *k_cache, const void *v_cache, +// const void *block_tables, +// const void *seq_lens, +// const void *cum_seq_lens_q, +// const void *alibi_slopes, +// void *stream_) const { + +// cudaStream_t stream = (cudaStream_t)stream_; + +// #define LAUNCH_KERNEL(Tdata, Tcompute) \ +// launchPagedAttentionPrefill( \ +// (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \ +// (const float *)alibi_slopes, \ +// _info.num_heads, _info.num_seqs, _info.num_kv_heads, \ +// _info.scale, _info.max_num_blocks_per_seq, \ +// _info.block_size, _info.total_q_tokens, \ +// _info.head_size, \ +// _info.kv_block_stride, _info.kv_head_stride, \ +// _info.q_stride, 
_info.q_head_stride, \ +// stream) + +// if (_info.dtype == INFINI_DTYPE_F16) { +// return LAUNCH_KERNEL(half, float); +// } else if (_info.dtype == INFINI_DTYPE_BF16) { +// return LAUNCH_KERNEL(__nv_bfloat16, float); +// } else if (_info.dtype == INFINI_DTYPE_F32) { +// return LAUNCH_KERNEL(float, float); +// } + +// return INFINI_STATUS_BAD_TENSOR_DTYPE; +// } + +// } // namespace op::paged_attention_prefill::nvidia diff --git a/test/infiniop/paged_attention.py b/test/infiniop/paged_attention.py index 882e9cfee..c1f10f9b7 100644 --- a/test/infiniop/paged_attention.py +++ b/test/infiniop/paged_attention.py @@ -100,13 +100,12 @@ def ref_single_query_cached_kv_attention( ] # Data types for testing -_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] # Tolerance map for different data types _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, - InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, } # Global flags for controlling test behavior diff --git a/test/infiniop/paged_attention_prefill.py b/test/infiniop/paged_attention_prefill.py index 4bbe762a8..65d843fae 100644 --- a/test/infiniop/paged_attention_prefill.py +++ b/test/infiniop/paged_attention_prefill.py @@ -32,10 +32,9 @@ (16, 128, 128, 128, 8, 16, 4), ] -_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] _TOLERANCE_MAP = { - InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2}, } From 4cd1f6881bf67dd14183c18307bacbca3612c2c7 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 02:21:52 +0000 Subject: [PATCH 21/25] issue/979 - removed commented paged attn codes --- .../nvidia/paged_attention_nvidia.cu | 148 ---------------- .../nvidia/paged_attention_prefill_nvidia.cu | 166 ------------------ 2 files changed, 314 deletions(-) diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu index 18b6ef073..a4bf82732 100644 --- a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu @@ -212,151 +212,3 @@ infiniStatus_t Descriptor::calculate( } } // namespace op::paged_attention::nvidia - -// #include - -// #include "../../../devices/nvidia/nvidia_common.cuh" -// #include "../../../devices/nvidia/nvidia_kernel_common.cuh" - -// #include "../../../reduce/cuda/reduce.cuh" -// #include "../cuda/kernel.cuh" -// #include "paged_attention_nvidia.cuh" - -// template -// INFINIOP_CUDA_KERNEL pagedAttention( -// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, -// const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes, -// const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq, -// const size_t block_size, -// const ptrdiff_t q_stride, -// const ptrdiff_t kv_block_stride, -// const ptrdiff_t kv_head_stride, -// const ptrdiff_t o_stride) { -// op::paged_attention::cuda::pagedAttentionKernel( -// out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale, -// max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride); -// } - -// namespace op::paged_attention::nvidia { - -// struct Descriptor::Opaque { -// std::shared_ptr internal; -// }; - -// 
Descriptor::~Descriptor() { -// delete _opaque; -// } - -// infiniStatus_t Descriptor::create( -// infiniopHandle_t handle, -// Descriptor **desc_ptr, -// infiniopTensorDescriptor_t out_desc, -// infiniopTensorDescriptor_t q_desc, -// infiniopTensorDescriptor_t k_cache_desc, -// infiniopTensorDescriptor_t v_cache_desc, -// infiniopTensorDescriptor_t block_tables_desc, -// infiniopTensorDescriptor_t seq_lens_desc, -// const std::optional &alibi_slopes_desc, -// float scale) { -// auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); -// CHECK_RESULT(info); -// *desc_ptr = new Descriptor( -// new Opaque{reinterpret_cast(handle)->internal()}, -// info.take(), 0, handle->device, handle->device_id); - -// return INFINI_STATUS_SUCCESS; -// } - -// template -// infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache, -// infiniDtype_t dtype, -// const void *block_tables, const void *seq_lens, const void *alibi_slopes, -// size_t num_heads, size_t num_seqs, -// size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, -// ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride, -// cudaStream_t stream) { -// dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1); -// dim3 block(NUM_THREADS); -// size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); - -// if (dtype == INFINI_DTYPE_F16) { -// pagedAttention -// <<>>( -// (half *)out, -// (const half *)q, (const half *)k_cache, (const half *)v_cache, -// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, -// scale, max_num_blocks_per_seq, block_size, -// q_stride, kv_block_stride, kv_head_stride, o_stride); -// } else if (dtype == INFINI_DTYPE_BF16) { -// pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> -// <<>>( -// (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache, -// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, -// scale, max_num_blocks_per_seq, block_size, -// q_stride, kv_block_stride, kv_head_stride, o_stride); -// } else if (dtype == INFINI_DTYPE_F32) { -// pagedAttention -// <<>>( -// (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache, -// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, -// scale, max_num_blocks_per_seq, block_size, -// q_stride, kv_block_stride, kv_head_stride, o_stride); -// } else { -// return INFINI_STATUS_BAD_TENSOR_DTYPE; -// } -// return INFINI_STATUS_SUCCESS; -// } - -// infiniStatus_t Descriptor::calculate( -// void *workspace, size_t workspace_size, -// void *out, const void *q, const void *k_cache, const void *v_cache, -// const void *block_tables, const void *seq_lens, const void *alibi_slopes, -// void *stream_) const { -// cudaStream_t stream = (cudaStream_t)stream_; - -// #define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \ -// launchKernel<__H_SIZE, __B_SIZE>( \ -// out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \ -// _info.num_heads, _info.num_seqs, \ -// _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \ -// _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \ -// stream); - -// #define SWITCH_HEAD_SIZE(__B_SIZE) \ -// switch 
(_info.head_size) { \ -// case 16: \ -// LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \ -// break; \ -// case 32: \ -// LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \ -// break; \ -// case 64: \ -// LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \ -// break; \ -// case 128: \ -// LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \ -// break; \ -// case 256: \ -// LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \ -// break; \ -// default: \ -// return INFINI_STATUS_BAD_TENSOR_SHAPE; \ -// } - -// if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { -// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024) -// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { -// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512) -// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { -// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096) -// } else { -// return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; -// } - -// #undef LAUNCH_HEADSIZE_BLOCKSIZE -// #undef SWITCH_HEAD_SIZE - -// return INFINI_STATUS_SUCCESS; -// } - -// } // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu index e95268a84..b8e98338a 100644 --- a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -1285,45 +1285,6 @@ infiniStatus_t Descriptor::create( const size_t n = info->total_q_tokens * info->num_heads; const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0; - // FA2-style kernel needs a workspace scratch for: - // - converting block_tables + total_kv_lens to int32 - // - storing softmax LSE (only required to satisfy the upstream kernel contract) - // bool want_fa2 = false; - // if (const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) { - // want_fa2 = (std::strcmp(k_env, "fa2") == 0); - // } - // bool fa2_materialize_kv = false; - // if (const char *env = std::getenv("INFINIOP_FA2_MATERIALIZE_PAGED_KV")) { - // fa2_materialize_kv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); - // } - // size_t fa2_workspace_bytes = 0; - // // FA2 prefill supports both fp16 and bf16 inputs (head_dim=128, block_size=256). - // // Workspace sizing is identical since both are 16-bit element types. - // if (want_fa2 && (info->dtype == INFINI_DTYPE_F16 || info->dtype == INFINI_DTYPE_BF16) && - // info->head_size == 128 && info->page_block_size == 256) { - // const size_t bt_bytes = info->num_seqs * info->max_num_blocks_per_seq * sizeof(int); - // const size_t len_bytes = info->num_seqs * sizeof(int); - // const size_t cuq_bytes = (info->num_seqs + 1) * sizeof(int); - // const size_t cuk_bytes = (info->num_seqs + 1) * sizeof(int); - // const size_t lse_bytes = info->num_heads * info->total_q_tokens * sizeof(float); - // // Add a small alignment slack since we sub-allocate with alignment. - // fa2_workspace_bytes = bt_bytes + len_bytes + cuq_bytes + cuk_bytes + lse_bytes + 64; - - // // Optional: materialize paged KV into the FA2-friendly physical layout - // // [num_blocks, page_block_size, kv_heads, head_dim] (token-major) to avoid - // // extremely strided loads when the framework stores KV as - // // [num_blocks, kv_heads, page_block_size, head_dim] (head-major). 
- // if (fa2_materialize_kv) { - // // Materialize per-seq contiguous KV in *sequence order*: - // // [num_seqs, max_num_blocks_per_seq * page_block_size, kv_heads, head_dim]. - // const size_t kv_elems = - // info->num_seqs * info->max_num_blocks_per_seq * info->page_block_size * info->num_kv_heads * info->head_size; - // const size_t kv_bytes = kv_elems * sizeof(uint16_t); // 16-bit (fp16/bf16) - // // K + V + alignment slack - // fa2_workspace_bytes += 2 * kv_bytes + 64; - // } - // } - const size_t workspace_bytes = splitkv_workspace_bytes; // const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes; @@ -1587,130 +1548,3 @@ infiniStatus_t Descriptor::calculate( } } // namespace op::paged_attention_prefill::nvidia - -// #include -// #include -// #include -// #include - -// #include "../../../devices/nvidia/nvidia_common.cuh" -// #include "../../../devices/nvidia/nvidia_kernel_common.cuh" -// #include "../cuda/kernel.cuh" -// #include "paged_attention_prefill_nvidia.cuh" - -// template -// infiniStatus_t launchPagedAttentionPrefill( -// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, -// const int64_t *block_tables, -// const int64_t *seq_lens, -// const int64_t *cum_seq_lens_q, -// const float *alibi_slopes, -// const size_t num_heads, -// const size_t num_seqs, -// const size_t num_kv_heads, -// const float scale, -// const size_t max_num_blocks_per_seq, -// const size_t block_size, -// const size_t total_q_tokens, -// const size_t head_size, -// const ptrdiff_t kv_block_stride, -// const ptrdiff_t kv_head_stride, -// const ptrdiff_t q_stride, -// const ptrdiff_t q_head_stride, -// cudaStream_t stream) { - -// if (total_q_tokens == 0 || num_heads == 0) { -// return INFINI_STATUS_BAD_TENSOR_SHAPE; -// } - -// dim3 grid(total_q_tokens, num_heads); -// dim3 block(head_size); - -// op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel -// <<>>( -// out, q, k_cache, v_cache, -// block_tables, seq_lens, cum_seq_lens_q, alibi_slopes, -// num_heads, num_kv_heads, scale, -// max_num_blocks_per_seq, block_size, -// kv_block_stride, kv_head_stride, -// q_stride, q_head_stride, -// head_size, -// num_seqs); - -// return INFINI_STATUS_SUCCESS; -// } - -// namespace op::paged_attention_prefill::nvidia { - -// struct Descriptor::Opaque { -// std::shared_ptr internal; -// }; - -// Descriptor::~Descriptor() { -// delete _opaque; -// } - -// infiniStatus_t Descriptor::create( -// infiniopHandle_t handle, -// Descriptor **desc_ptr, -// infiniopTensorDescriptor_t out_desc, -// infiniopTensorDescriptor_t q_desc, -// infiniopTensorDescriptor_t k_cache_desc, -// infiniopTensorDescriptor_t v_cache_desc, -// infiniopTensorDescriptor_t block_tables_desc, -// infiniopTensorDescriptor_t seq_lens_desc, -// infiniopTensorDescriptor_t cum_seq_lens_q_desc, -// const std::optional &alibi_slopes_desc, -// float scale) { - -// auto info = PagedAttentionPrefillInfo::create( -// out_desc, q_desc, k_cache_desc, v_cache_desc, -// block_tables_desc, seq_lens_desc, -// cum_seq_lens_q_desc, -// alibi_slopes_desc, scale); - -// CHECK_RESULT(info); - -// *desc_ptr = new Descriptor( -// new Opaque{reinterpret_cast(handle)->internal()}, -// info.take(), 0, handle->device, handle->device_id); - -// return INFINI_STATUS_SUCCESS; -// } - -// infiniStatus_t Descriptor::calculate( -// void *workspace, size_t workspace_size, -// void *out, const void *q, const void *k_cache, const void *v_cache, -// const void *block_tables, -// const void *seq_lens, -// const void *cum_seq_lens_q, 
-// const void *alibi_slopes, -// void *stream_) const { - -// cudaStream_t stream = (cudaStream_t)stream_; - -// #define LAUNCH_KERNEL(Tdata, Tcompute) \ -// launchPagedAttentionPrefill( \ -// (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ -// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \ -// (const float *)alibi_slopes, \ -// _info.num_heads, _info.num_seqs, _info.num_kv_heads, \ -// _info.scale, _info.max_num_blocks_per_seq, \ -// _info.block_size, _info.total_q_tokens, \ -// _info.head_size, \ -// _info.kv_block_stride, _info.kv_head_stride, \ -// _info.q_stride, _info.q_head_stride, \ -// stream) - -// if (_info.dtype == INFINI_DTYPE_F16) { -// return LAUNCH_KERNEL(half, float); -// } else if (_info.dtype == INFINI_DTYPE_BF16) { -// return LAUNCH_KERNEL(__nv_bfloat16, float); -// } else if (_info.dtype == INFINI_DTYPE_F32) { -// return LAUNCH_KERNEL(float, float); -// } - -// return INFINI_STATUS_BAD_TENSOR_DTYPE; -// } - -// } // namespace op::paged_attention_prefill::nvidia From 7a18d2413807fb301de5c918f0cdb73a1a2ee9b6 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 02:33:08 +0000 Subject: [PATCH 22/25] issue/983 - adapted the optimized paged attention to metax --- src/infiniop/devices/metax/metax_ht2mc.h | 3 + .../devices/metax/metax_kernel_common.h | 6 + .../metax/paged_attention_hd128.maca | 1028 +++++++++++ .../metax/paged_attention_hd64.maca | 528 ++++++ .../metax/paged_attention_metax.h | 8 + .../metax/paged_attention_metax.maca | 218 +++ src/infiniop/ops/paged_attention/operator.cc | 30 +- .../cuda/kernel_v2.cuh | 2 + .../metax/paged_attention_prefill_metax.h | 8 + .../metax/paged_attention_prefill_metax.maca | 1554 +++++++++++++++++ .../ops/paged_attention_prefill/operator.cc | 15 + .../paged_caching/metax/paged_caching_metax.h | 8 + .../metax/paged_caching_metax.maca | 157 ++ src/infiniop/ops/paged_caching/operator.cc | 30 +- 14 files changed, 3565 insertions(+), 30 deletions(-) create mode 100644 src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca create mode 100644 src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca create mode 100644 src/infiniop/ops/paged_attention/metax/paged_attention_metax.h create mode 100644 src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca create mode 100644 src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h create mode 100644 src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca create mode 100644 src/infiniop/ops/paged_caching/metax/paged_caching_metax.h create mode 100644 src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca diff --git a/src/infiniop/devices/metax/metax_ht2mc.h b/src/infiniop/devices/metax/metax_ht2mc.h index a1c8c1ffe..447792b67 100644 --- a/src/infiniop/devices/metax/metax_ht2mc.h +++ b/src/infiniop/devices/metax/metax_ht2mc.h @@ -85,6 +85,9 @@ #define hcclSuccess mcclSuccess #define hcclCommDestroy mcclCommDestroy #define hcclAllReduce mcclAllReduce +#define hcGetDevice mcGetDevice +#define hcDeviceAttributeMultiProcessorCount mcDeviceAttributeMultiProcessorCount +#define hcDeviceGetAttribute mcDeviceGetAttribute #define hcStreamCaptureMode mcStreamCaptureMode #define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal #define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal diff --git a/src/infiniop/devices/metax/metax_kernel_common.h b/src/infiniop/devices/metax/metax_kernel_common.h index 
d850e9d04..f58fe6c53 100644 --- a/src/infiniop/devices/metax/metax_kernel_common.h +++ b/src/infiniop/devices/metax/metax_kernel_common.h @@ -19,6 +19,12 @@ using cuda_bfloat16 = hpcc_bfloat16; using cuda_bfloat162 = hpcc_bfloat162; using cuda_fp8_e4m3 = __hpcc_fp8_e4m3; +#ifdef ENABLE_METAX_MC_API +using __nv_bfloat16 = __maca_bfloat16; +#else +using __nv_bfloat16 = __hpcc_bfloat16; +#endif + namespace device::metax { // get the memory offset of the given element in a tensor given its flat index diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca b/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca new file mode 100644 index 000000000..131ac2343 --- /dev/null +++ b/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca @@ -0,0 +1,1028 @@ +#ifdef ENABLE_METAX_MC_API +#include +#else +#include +#endif + +#include +#include +#include +#include + +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention::metax { + +namespace { +constexpr int kMaxSplits = 8; + +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline int getSmCount() { + int device = 0; + if (hcGetDevice(&device) != hcSuccess) { + return 0; + } + int sm_count = 0; + if (hcDeviceGetAttribute(&sm_count, hcDeviceAttributeMultiProcessorCount, device) != hcSuccess) { + return 0; + } + return sm_count; +} + +// A lightweight FA2-style "waves" heuristic. +// +// Important: our split-kv kernel shards the KV sequence length, so the main "work" +// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k +// (max pages * page size), which matches common decode microbench where all seqs +// share the same cache length. +inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) { + if (sm_count <= 0) { + return 1; + } + if (num_heads == 0 || num_seqs == 0) { + return 1; + } + if (seqlen_k <= 256) { + return 1; + } + + const size_t base_blocks = num_heads * num_seqs; + int best_splits = 1; + // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens. + size_t best_score = (ceilDiv(base_blocks, static_cast(sm_count)) * seqlen_k); + + size_t prev_work_per_block = seqlen_k; + for (int s = 2; s <= kMaxSplits; ++s) { + const size_t blocks = base_blocks * static_cast(s); + const size_t waves_split = ceilDiv(blocks, static_cast(sm_count)); + const size_t work_per_block = ceilDiv(seqlen_k, static_cast(s)); + // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant. + if (work_per_block == prev_work_per_block) { + continue; + } + prev_work_per_block = work_per_block; + // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit. 
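+ // Illustrative example of the score computed below (hypothetical shape, not a measurement):
+ // sm_count=104, num_heads=32, num_seqs=1 (base_blocks=32), seqlen_k=4096.
+ // Baseline: ceil(32/104) * 4096 = 4096. s=2: 1*2048 + 1 = 2049; s=3: 1*1366 + 1 = 1367;
+ // s=4: 2*1024 + 1 = 2049; s=8: 3*512 + 1 = 1537. The loop settles on num_splits=3.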
+ const size_t waves_combine = ceilDiv(base_blocks, static_cast(sm_count)); + const size_t score = waves_split * work_per_block + waves_combine; + if (score < best_score) { + best_score = score; + best_splits = s; + } + } + return best_splits; +} +} // namespace + +inline bool envBool(const char *name) { + if (const char *env = std::getenv(name)) { + return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + return false; +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeWarpKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Cta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Default CTA variant (lower overhead). 
+ op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128CtaTile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Cta32( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Experimental 1-warp CTA variant for head_dim=128 (kPack=4). + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Cta32Tile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128CtaGqa4( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128). 
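+ // This fused path is only taken from the dispatcher below when alibi_slopes == nullptr,
+ // split-kv is disabled, and num_heads == 4 * num_kv_heads, so the K/V tiles loaded by the
+ // CTA are reused across the four query heads of the GQA group instead of being re-read per head.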
+ op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCta( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCta32( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t 
k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, o_stride); +} + +template +infiniStatus_t launch_decode_hd128_impl( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + + // Default decode config (2026-01-22): + // decode_flash_cta8_64_gqa_splitkv_4 + // Users can override any knob via the corresponding INFINIOP_FLASH_* env vars. + bool use_cta = true; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) { + // Backward-compatible: any non-"cta" value means "warp". + use_cta = (std::strcmp(env, "cta") == 0); + } + bool use_gqa_fused = true; + if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) { + if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) { + use_gqa_fused = false; + } else { + use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + } + int cta_tile = 8; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) { + const int v = std::atoi(env); + if (v == 8 || v == 16) { + cta_tile = v; + } + } + int cta_threads = 64; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) { + const int v = std::atoi(env); + if (v == 32 || v == 64) { + cta_threads = v; + } + } + dim3 block(use_cta ? 
static_cast(cta_threads) : 32); + + bool use_split = true; + bool use_split_auto = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + if (std::strcmp(env, "auto") == 0) { + use_split_auto = true; + use_split = false; + } else { + if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) { + use_split = false; + } else { + use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + } + } + int num_splits = 4; + bool fixed_num_splits = true; + if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) { + if (std::strcmp(env, "auto") == 0) { + fixed_num_splits = false; + } else { + num_splits = std::atoi(env); + fixed_num_splits = (num_splits > 0); + } + } + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH"); + auto dump_dispatch = [&](const char *path) { + if (!debug_dispatch) { + return; + } + // Avoid spamming: only print when the key dispatch signature changes. + struct Sig { + const char *path; + int dtype; + size_t heads; + size_t kv_heads; + size_t seqs; + size_t pbs; + size_t max_blocks; + int cta_tile; + int cta_threads; + int split; + int split_auto; + int num_splits; + int fixed; + int gqa_fused; + }; + static Sig last{}; + static bool has_last = false; + + Sig cur{ + path, + static_cast(dtype), + num_heads, + num_kv_heads, + num_seqs, + page_block_size, + max_num_blocks_per_seq, + cta_tile, + cta_threads, + static_cast(use_split), + static_cast(use_split_auto), + num_splits, + static_cast(fixed_num_splits), + static_cast(use_gqa_fused), + }; + + if (has_last && cur.path == last.path && cur.dtype == last.dtype && cur.heads == last.heads && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs && cur.pbs == last.pbs && cur.max_blocks == last.max_blocks && cur.cta_tile == last.cta_tile && cur.cta_threads == last.cta_threads && cur.split == last.split && cur.split_auto == last.split_auto && cur.num_splits == last.num_splits && cur.fixed == last.fixed && cur.gqa_fused == last.gqa_fused) { + return; + } + last = cur; + has_last = true; + + fprintf(stderr, + "[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu " + "pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n", + path, static_cast(dtype), num_heads, num_kv_heads, num_seqs, + page_block_size, max_num_blocks_per_seq, cta_tile, cta_threads, + static_cast(use_split), static_cast(use_split_auto), num_splits, static_cast(fixed_num_splits), + static_cast(use_gqa_fused)); + }; + + // Split-kv auto mode: decide whether to split based on a heuristic. + if (use_split_auto) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). + const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + // If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead. 
+ use_split = (num_splits > 1); + } + + // const bool debug_dispatch = [] { + // if (const char *env = std::getenv("INFINIOP_FLASH_DEBUG_DISPATCH")) { + // return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + // } + // return false; + // }(); + + // const char *selected_path = "unknown"; + + // Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled. + // This reuses K/V loads across query heads that share the same KV head. + // Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled). + if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0 && num_heads == num_kv_heads * 4) { + dump_dispatch("cta_gqa_fused"); + dim3 grid_gqa(static_cast(num_kv_heads), static_cast(num_seqs), 1); + if (dtype == INFINI_DTYPE_F16) { + flashAttentionDecodeHd128CtaGqa4<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, nullptr, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + flashAttentionDecodeHd128CtaGqa4<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, nullptr, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + dim3 grid(static_cast(num_heads), static_cast(num_seqs), 1); + if (use_split) { + dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp"); + // } + if (!fixed_num_splits) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). + const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + } + + const size_t n = num_seqs * num_heads; + const size_t acc_elems = static_cast(kMaxSplits) * n * 128; + const size_t m_elems = static_cast(kMaxSplits) * n; + const size_t l_elems = static_cast(kMaxSplits) * n; + const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float); + if (workspace == nullptr || workspace_size < needed_bytes) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *ws = static_cast(workspace); + float *partial_acc = ws; + float *partial_m = partial_acc + acc_elems; + float *partial_l = partial_m + m_elems; + + dim3 grid_split(static_cast(num_heads), static_cast(num_seqs), static_cast(num_splits)); + dim3 block_split(use_cta ? 
static_cast(cta_threads) : 32); + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_threads == 32) { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCta32Tile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta32<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } else { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCtaTile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } + } else { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + flashAttentionDecodeHd128SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_threads == 32) { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCta32Tile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta32<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } else { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCtaTile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + 
flashAttentionDecodeHd128SplitKvCta<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } + } else { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + flashAttentionDecodeHd128SplitKvCombine<__nv_bfloat16><<>>( + static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit"); + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_tile == 16) { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32Tile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128CtaTile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128Cta<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } + } else { + flashAttentionDecodeHd128Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_tile == 16) { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32Tile16<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128CtaTile16<<>>( + 
static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128Cta<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } + } else { + flashAttentionDecodeHd128Warp<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t launch_decode_hd128_i64( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int64_t *block_tables, + const int64_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd128_i32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int32_t *block_tables, + const int32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + 
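+// Callers of the launch_decode_hd128_{i64,i32,u32} wrappers should note:
+// * Whenever the split-kv path can be taken, `workspace` must hold at least
+//   kMaxSplits * num_seqs * num_heads * (128 + 2) floats; sizing always uses kMaxSplits (8),
+//   not the runtime num_splits. For example (illustrative shape), num_seqs=8 and num_heads=32
+//   give n=256, i.e. 8*256*130 = 266,240 floats (just over 1 MiB).
+// * Dispatch is tunable via the environment variables parsed in launch_decode_hd128_impl:
+//   INFINIOP_FLASH_DECODE_KERNEL ("cta" selects the CTA kernels, anything else the warp kernel),
+//   INFINIOP_FLASH_GQA_FUSED (0/1), INFINIOP_FLASH_CTA_TILE (8|16), INFINIOP_FLASH_CTA_THREADS (32|64),
+//   INFINIOP_FLASH_DECODE_SPLITKV (0|1|auto), INFINIOP_FLASH_NUM_SPLITS (1..8 or "auto"),
+//   and INFINIOP_FLASH_DEBUG_DISPATCH / INFINIOP_FLASH_DEBUG_SPLITS for stderr diagnostics.
+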
+infiniStatus_t launch_decode_hd128_u32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const uint32_t *block_tables, + const uint32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +} // namespace op::paged_attention::metax diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca b/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca new file mode 100644 index 000000000..2f8b95b3a --- /dev/null +++ b/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca @@ -0,0 +1,528 @@ +#ifdef ENABLE_METAX_MC_API +#include +#else +#include +#endif + +#include +#include +#include +#include + +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention::metax { + +namespace { +constexpr int kMaxSplits = 8; + +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline int getSmCount() { + int device = 0; + if (hcGetDevice(&device) != hcSuccess) { + return 0; + } + int sm_count = 0; + if (hcDeviceGetAttribute(&sm_count, hcDeviceAttributeMultiProcessorCount, device) != hcSuccess) { + return 0; + } + return sm_count; +} + +// A lightweight FA2-style "waves" heuristic. +// +// Important: our split-kv kernel shards the KV sequence length, so the main "work" +// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k +// (max pages * page size), which matches common decode microbench where all seqs +// share the same cache length. +inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) { + if (sm_count <= 0) { + return 1; + } + if (num_heads == 0 || num_seqs == 0) { + return 1; + } + if (seqlen_k <= 256) { + return 1; + } + + const size_t base_blocks = num_heads * num_seqs; + int best_splits = 1; + // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens. + size_t best_score = (ceilDiv(base_blocks, static_cast(sm_count)) * seqlen_k); + + size_t prev_work_per_block = seqlen_k; + for (int s = 2; s <= kMaxSplits; ++s) { + const size_t blocks = base_blocks * static_cast(s); + const size_t waves_split = ceilDiv(blocks, static_cast(sm_count)); + const size_t work_per_block = ceilDiv(seqlen_k, static_cast(s)); + // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant. + if (work_per_block == prev_work_per_block) { + continue; + } + prev_work_per_block = work_per_block; + // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit. 
+ const size_t waves_combine = ceilDiv(base_blocks, static_cast(sm_count)); + const size_t score = waves_split * work_per_block + waves_combine; + if (score < best_score) { + best_score = score; + best_splits = s; + } + } + return best_splits; +} +} // namespace + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd64Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeWarpKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd64Cta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Default CTA variant (lower overhead). + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd64CtaTile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd64SplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel( + partial_acc, partial_m, partial_l, + q, k_cache, 
v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_METAX_KERNEL flashAttentionDecodeHd64SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, o_stride); +} + +template +infiniStatus_t launch_decode_hd64_impl( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + + dim3 grid(static_cast(num_heads), static_cast(num_seqs), 1); + bool use_cta = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) { + use_cta = (std::strcmp(env, "cta") == 0); + } + int cta_tile = 8; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) { + const int v = std::atoi(env); + if (v == 8 || v == 16) { + cta_tile = v; + } + } + // For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads. + dim3 block(32); + + bool use_split = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 4; + bool fixed_num_splits = false; + if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) { + if (std::strcmp(env, "auto") == 0) { + fixed_num_splits = false; + } else { + num_splits = std::atoi(env); + fixed_num_splits = (num_splits > 0); + } + } + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + if (use_split) { + if (use_cta) { + // We currently only implement the split-kv path with warp kernels. + // The CTA kernel is a separate non-split implementation. + static bool warned = false; + if (!warned) { + warned = true; + fprintf(stderr, + "[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta " + "(CTA split-kv not implemented yet)\n"); + } + } + + if (!fixed_num_splits) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). 
+ const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + } + + const size_t n = num_seqs * num_heads; + const size_t acc_elems = static_cast(kMaxSplits) * n * 64; + const size_t m_elems = static_cast(kMaxSplits) * n; + const size_t l_elems = static_cast(kMaxSplits) * n; + const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float); + if (workspace == nullptr || workspace_size < needed_bytes) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *ws = static_cast(workspace); + float *partial_acc = ws; + float *partial_m = partial_acc + acc_elems; + float *partial_l = partial_m + m_elems; + + dim3 grid_split(static_cast(num_heads), static_cast(num_seqs), static_cast(num_splits)); + dim3 block_split(32); + + if (dtype == INFINI_DTYPE_F16) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<__nv_bfloat16><<>>( + static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_tile == 16) { + flashAttentionDecodeHd64CtaTile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd64Cta<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + flashAttentionDecodeHd64Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, 
k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_tile == 16) { + flashAttentionDecodeHd64CtaTile16<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd64Cta<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + flashAttentionDecodeHd64Warp<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t launch_decode_hd64_i64( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int64_t *block_tables, + const int64_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd64_i32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int32_t *block_tables, + const int32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd64_u32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t 
dtype, + const uint32_t *block_tables, + const uint32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + hcStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +} // namespace op::paged_attention::metax diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_metax.h b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.h new file mode 100644 index 000000000..82a5b3e59 --- /dev/null +++ b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.h @@ -0,0 +1,8 @@ +#ifndef __PAGED_ATTENTION_METAX_H__ +#define __PAGED_ATTENTION_METAX_H__ + +#include "../paged_attention.h" + +DESCRIPTOR(metax) + +#endif // __PAGED_ATTENTION_METAX_H__ diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca new file mode 100644 index 000000000..fd0f4d576 --- /dev/null +++ b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca @@ -0,0 +1,218 @@ +#ifdef ENABLE_METAX_MC_API +#include +#else +#include +#endif + +#include +#include +#include + +#include "../../../devices/metax/metax_common.h" +#include "paged_attention_metax.h" + +namespace op::paged_attention::metax { + +infiniStatus_t launch_decode_hd64_i64( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + hcStream_t stream); + +infiniStatus_t launch_decode_hd64_i32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + hcStream_t stream); + +infiniStatus_t launch_decode_hd64_u32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, 
ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + hcStream_t stream); + +infiniStatus_t launch_decode_hd128_i64( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + hcStream_t stream); + +infiniStatus_t launch_decode_hd128_i32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + hcStream_t stream); + +infiniStatus_t launch_decode_hd128_u32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + hcStream_t stream); + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + const std::optional &alibi_slopes_desc, + float scale) { + + auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale); + CHECK_RESULT(info_res); + auto info = info_res.take(); + // Reserve workspace for optional split-kv decode (partial acc + m/l). + // Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits. 
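// Illustrative sketch of the layout this sizing assumes (not taken from the
// patch): for every (split, seq, head) triple the split kernels write
// head_size partial-accumulator floats plus one running max (m) and one
// running sum (l), which is where the (head_size + 2) factor below comes from.
struct SplitKvWorkspaceSketch {
    float *partial_acc;   // [max_splits][num_seqs * num_heads][head_size]
    float *partial_m;     // [max_splits][num_seqs * num_heads]
    float *partial_l;     // [max_splits][num_seqs * num_heads]
};

inline SplitKvWorkspaceSketch carveSplitKvWorkspaceSketch(void *workspace, size_t num_seqs,
                                                          size_t num_heads, size_t head_size,
                                                          size_t max_splits) {
    const size_t n = num_seqs * num_heads;
    float *base = static_cast<float *>(workspace);
    SplitKvWorkspaceSketch ws;
    ws.partial_acc = base;                                      // max_splits * n * head_size floats
    ws.partial_m   = ws.partial_acc + max_splits * n * head_size;
    ws.partial_l   = ws.partial_m + max_splits * n;             // total: max_splits * n * (head_size + 2)
    return ws;
}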
+ constexpr size_t kMaxSplits = 8; + const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float); + const size_t workspace_bytes = kMaxSplits * per_split; + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info, workspace_bytes, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + const void *block_tables, const void *cache_lens, const void *alibi_slopes, + void *stream_) const { + + bool need_workspace = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + // "auto" may enable split-kv depending on the runtime heuristic. + need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } else { + // Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace. + need_workspace = (_info.head_size == 128); + } + if (need_workspace && workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto stream = static_cast(stream_); + + const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); + + if (_info.index_dtype == INFINI_DTYPE_I64) { + const auto *block_table_i64 = static_cast(block_tables); + const auto *cache_lens_i64 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_i64( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i64, cache_lens_i64, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_i64( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i64, cache_lens_i64, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (_info.index_dtype == INFINI_DTYPE_I32) { + const auto *block_table_i32 = static_cast(block_tables); + const auto *cache_lens_i32 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_i32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i32, cache_lens_i32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_i32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i32, cache_lens_i32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, 
_info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (_info.index_dtype == INFINI_DTYPE_U32) { + const auto *block_table_u32 = static_cast(block_tables); + const auto *cache_lens_u32 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_u32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_u32, cache_lens_u32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_u32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_u32, cache_lens_u32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::paged_attention::metax diff --git a/src/infiniop/ops/paged_attention/operator.cc b/src/infiniop/ops/paged_attention/operator.cc index 1d7d4fee3..46bea9e1e 100644 --- a/src/infiniop/ops/paged_attention/operator.cc +++ b/src/infiniop/ops/paged_attention/operator.cc @@ -5,9 +5,9 @@ #ifdef ENABLE_NVIDIA_API #include "nvidia/paged_attention_nvidia.cuh" #endif -// #ifdef ENABLE_METAX_API -// #include "metax/paged_attention_metax.h" -// #endif +#ifdef ENABLE_METAX_API +#include "metax/paged_attention_metax.h" +#endif __C infiniStatus_t infiniopCreatePagedAttentionDescriptor( infiniopHandle_t handle, @@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor( #ifdef ENABLE_NVIDIA_API CREATE(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // CREATE(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( #ifdef ENABLE_NVIDIA_API GET(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // GET(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention( #ifdef ENABLE_NVIDIA_API CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // CALCULATE(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( #ifdef ENABLE_NVIDIA_API DESTROY(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // DESTROY(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh index 6790f12d8..28bcccaeb 100644 --- 
a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh +++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh @@ -1,10 +1,12 @@ #ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ #define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ +#ifdef ENABLE_NVIDIA_API #include #include #include #include +#endif #include #include diff --git a/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h new file mode 100644 index 000000000..03b6cef3c --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h @@ -0,0 +1,8 @@ +#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__ +#define __PAGED_ATTENTION_PREFILL_METAX_H__ + +#include "../paged_attention_prefill.h" + +DESCRIPTOR(metax) + +#endif // __PAGED_ATTENTION_PREFILL_METAX_H__ diff --git a/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca new file mode 100644 index 000000000..b4bd01e8d --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca @@ -0,0 +1,1554 @@ +#ifdef ENABLE_METAX_MC_API +#include +#else +#include +#endif + +#include +#include +#include +#include + +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" + +// #include "paged_attention_prefill_fa2.cuh" +#include "paged_attention_prefill_metax.h" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention_prefill::metax { + +namespace { +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) { + // Heuristic auto-dispatch (v0.4): + // - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256. + // - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80). + // + // Users can always override via INFINIOP_FLASH_PREFILL_KERNEL. + if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) { + if (info.head_size == 128) { + return "warpcta8pipe"; + } + // For head_size=64 we keep the previous default until we have broader perf coverage. + } + return "warpcta8"; +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel). 
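// Illustrative sketch (not taken from the patch) of what the global-token
// kernel mentioned above has to do: map a flat query-token index back to its
// sequence. This assumes cu_seqlens_q holds exclusive prefix sums
// (cu_seqlens_q[0] == 0, cu_seqlens_q[num_seqs] == total_q_tokens), the usual
// varlen convention; requires <cstdint>.
__device__ inline int seqIdxFromGlobalTokenSketch(const int64_t *cu_seqlens_q,
                                                  int num_seqs, int64_t token) {
    int lo = 0, hi = num_seqs;              // find the last i with cu_seqlens_q[i] <= token
    while (lo + 1 < hi) {
        const int mid = (lo + hi) / 2;
        if (cu_seqlens_q[mid] <= token) {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    return lo;                              // in-sequence position: token - cu_seqlens_q[lo]
}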
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride, + q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel). + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride, + q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 4 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 4 warps per CTA, one warp per query token. 
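// Illustrative scalar sketch (not taken from the patch) of the single-pass
// softmax recurrence that these warp/CTA prefill kernels are built around.
// State per query row: running max m, running sum l, accumulator acc[head_dim];
// initialize m = -INFINITY, l = 0, acc = 0. Needs <cmath>.
inline void onlineSoftmaxTileUpdateSketch(float &m, float &l, float *acc, size_t head_dim,
                                          const float *tile_scores,  // [tile_len], already scaled
                                          const float *tile_v,       // [tile_len * head_dim]
                                          size_t tile_len) {
    float tile_max = -INFINITY;
    for (size_t j = 0; j < tile_len; ++j) {
        tile_max = fmaxf(tile_max, tile_scores[j]);
    }
    const float m_new = fmaxf(m, tile_max);
    const float rescale = expf(m - m_new);   // shrink the contribution of earlier tiles
    l *= rescale;
    for (size_t d = 0; d < head_dim; ++d) {
        acc[d] *= rescale;
    }
    for (size_t j = 0; j < tile_len; ++j) {
        const float p = expf(tile_scores[j] - m_new);
        l += p;
        for (size_t d = 0; d < head_dim; ++d) {
            acc[d] += p * tile_v[j * head_dim + d];   // accumulate P * V
        }
    }
    m = m_new;   // after the last KV tile, the output row is acc[d] / l
}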
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages. + // Note: we keep K in shared memory but load V from global to stay within the per-block shared limit. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token. 
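// Illustrative arithmetic (not taken from the patch) behind the K-only
// shared-memory variant above, assuming a ~48 KiB default per-block
// shared-memory budget as on CUDA-class GPUs; the exact MACA limit may differ:
//
//   K tile: 128 rows * 128 (head_dim) * sizeof(half) = 32 KiB
//   K + V : 64 KiB, which would exceed the assumed budget,
//
// so the tile_n = 128 kernel streams V from global memory instead.
constexpr size_t kSketchTileN     = 128;
constexpr size_t kSketchHeadDim   = 128;
constexpr size_t kSketchSmemLimit = 48 * 1024;   // assumed per-block budget
constexpr size_t kSketchKTile     = kSketchTileN * kSketchHeadDim * 2 /* sizeof(half) */;
static_assert(kSketchKTile <= kSketchSmemLimit, "K tile alone fits in the assumed budget");
static_assert(2 * kSketchKTile > kSketchSmemLimit, "K + V tiles would not fit");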
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, with cp.async pipelining. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma( + half *out, + const half *q, + const half *k_cache, + const half *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, with cp.async pipelining. 
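// Illustrative sketch (not taken from the patch) of the double-buffered tile
// loop behind the *Pipe kernels: tile t+1 is staged into the alternate shared
// buffer while tile t is consumed. Plain copies are used here; on SM80-class
// hardware the staging copy is what cp.async would issue asynchronously
// (and, as noted earlier, cp.async degrades to a synchronous copy below SM80).
template <typename T>
__device__ void pipelinedTileLoopSketch(T *smem /* [2 * tile_elems] */, const T *gmem,
                                        int num_tiles, int tile_elems) {
    int buf = 0;
    for (int e = threadIdx.x; e < tile_elems; e += blockDim.x) {
        smem[buf * tile_elems + e] = gmem[e];          // prologue: stage tile 0
    }
    __syncthreads();
    for (int t = 0; t < num_tiles; ++t) {
        const int next = buf ^ 1;
        if (t + 1 < num_tiles) {
            for (int e = threadIdx.x; e < tile_elems; e += blockDim.x) {
                smem[next * tile_elems + e] = gmem[(t + 1) * tile_elems + e];   // prefetch tile t+1
            }
        }
        // ... consume smem[buf * tile_elems .. buf * tile_elems + tile_elems) for tile t ...
        __syncthreads();   // prefetch visible, and reads of buf finished, before swapping
        buf = next;
    }
}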
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + size_t total_q_tokens, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + // Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch: + // blockIdx.z in [0, num_splits * num_m_blocks). + const int num_m_blocks = static_cast((total_q_tokens + 8 - 1) / 8); + const int bz = static_cast(blockIdx.z); + const int split_idx = bz / num_m_blocks; + const int m_block = bz - split_idx * num_m_blocks; + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + size_t total_q_tokens, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + const int num_m_blocks = static_cast((total_q_tokens + 8 - 1) / 8); + const int bz = static_cast(blockIdx.z); + const int split_idx = bz / num_m_blocks; + const int m_block = bz - split_idx * num_m_blocks; + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + size_t 
total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + size_t total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 16 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 16 warps per CTA, one warp per query token. 
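// Illustrative scalar sketch (not taken from the patch) of what the
// *SplitKvCombine wrappers above compute per (token, head): each split s left
// an unnormalized accumulator acc_s[d] with its local max m_s and local sum
// l_s, and the splits are merged with a max-shifted rescale. Needs <cmath>.
inline void combineSplitsSketch(float *out, size_t head_dim,
                                const float *partial_acc,  // [num_splits][head_dim]
                                const float *partial_m,    // [num_splits]
                                const float *partial_l,    // [num_splits]
                                int num_splits) {
    float m = -INFINITY;
    for (int s = 0; s < num_splits; ++s) {
        m = fmaxf(m, partial_m[s]);
    }
    float l = 0.f;
    for (size_t d = 0; d < head_dim; ++d) {
        out[d] = 0.f;
    }
    for (int s = 0; s < num_splits; ++s) {
        const float w = expf(partial_m[s] - m);   // rescale split s to the global max
        l += w * partial_l[s];
        for (size_t d = 0; d < head_dim; ++d) {
            out[d] += w * partial_acc[s * head_dim + d];
        }
    }
    const float inv_l = (l > 0.f) ? 1.f / l : 0.f;
    for (size_t d = 0; d < head_dim; ++d) {
        out[d] *= inv_l;                          // normalized attention output row
    }
}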
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +infiniStatus_t launch_prefill_ref( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + const dim3 grid(static_cast(total_q_tokens), static_cast(num_heads), 1); + const dim3 block(static_cast(head_size), 1, 1); + + if (head_size == 64) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, num_seqs); + return INFINI_STATUS_SUCCESS; + } + + if (head_size == 128) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, num_seqs); + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +template +infiniStatus_t launch_prefill_warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + const dim3 block(32, 1, 1); + // Global-token launch: + // - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch + // - matches PagedAttention varlen (cu_seqlens) mental model better + const dim3 grid(static_cast(num_heads), + static_cast(total_q_tokens), + 1); + + switch (head_size) { + case 64: + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, 
alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq, + page_block_size, block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq, + page_block_size, block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + constexpr int kWarps = 4; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + constexpr int kWarps = 8; + 
constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta8 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8Pipe + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta8Pipe + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8mma( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + 
ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + // Current WMMA kernel only supports fp16 + head_dim=128. + if constexpr (!std::is_same_v) { + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + + if (head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts. + // Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1. + const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE"); + const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0); + const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size; + if (!force_mma && seqlen_k_est > 4096) { + static bool warned = false; + if (!warned) { + std::fprintf(stderr, + "[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). " + "Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n", + seqlen_k_est); + warned = true; + } + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + + // WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel. 
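// Illustrative sketch (not taken from the patch) of the warp-level matrix
// (WMMA) primitive the warpcta8mma path builds on, and the reason for the
// SM70+ gate above. Shown with CUDA's nvcuda::wmma interface from <mma.h>;
// whether the MACA toolchain exposes the same API is an assumption here.
// Launch with at least one full warp.
__global__ void wmmaTileSketch(const half *A, const half *B, float *C) {
    using namespace nvcuda;
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16);            // leading dimension 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);   // 16x16x16 half -> float tile
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}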
+ int device = 0; + hcDeviceProp_t prop{}; + if (hcGetDevice(&device) == hcSuccess && hcGetDeviceProperties(&prop, device) == hcSuccess) { + if (prop.major < 7) { + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + } + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(16)))); + + PagedAttentionPrefillHd128WarpCta8Mma + <<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launch_prefill_warpcta8pipe_splitkv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + constexpr int kMaxSplits = 8; + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast(kWarps)); + // Single kernel launch with split_idx encoded in grid.z: + // blockIdx.z in [0, num_splits * num_m_blocks). 
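// Illustrative sketch (not taken from the patch): the flattened encoding is
// z = split_idx * num_m_blocks + m_block, which the kernel inverts with a
// divide and a remainder. A host-side check that the flattened z dimension
// stays within the launch limit (65535 on CUDA-class devices; the MACA cap is
// assumed comparable):
inline bool splitKvGridZFitsSketch(size_t total_q_tokens, int num_splits, int warps_per_cta = 8) {
    const size_t num_m_blocks = (total_q_tokens + warps_per_cta - 1) / warps_per_cta;
    const size_t grid_z = num_m_blocks * static_cast<size_t>(num_splits);
    return grid_z <= 65535;   // assumed gridDim.z cap
}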
+ const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(num_m_blocks * static_cast(num_splits))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8PipeSplitKv + <<>>( + partial_acc, partial_m, partial_l, num_splits, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); + break; + case 128: + PagedAttentionPrefillHd128WarpCta8PipeSplitKv + <<>>( + partial_acc, partial_m, partial_l, num_splits, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); + break; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Combine: one warp per (token, head). + const dim3 block2(32); + const dim3 grid2(static_cast(num_heads), static_cast(total_q_tokens), 1); + switch (head_size) { + case 64: + PagedAttentionPrefillHd64SplitKvCombine + <<>>( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128SplitKvCombine + <<>>( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8n128( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + // Only meaningful for head_dim=128. 
+ if (head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + PagedAttentionPrefillHd128WarpCta8N128 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launch_prefill_warpcta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + hcStream_t stream) { + + constexpr int kWarps = 16; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta16 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta16 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} +} // namespace + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t total_kv_lens_desc, + infiniopTensorDescriptor_t cum_seqlens_q_desc, + const std::optional &alibi_slopes_desc, + float scale) { + + auto info = PagedAttentionPrefillInfo::create( + out_desc, q_desc, k_cache_desc, v_cache_desc, + block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc, + alibi_slopes_desc, scale); + CHECK_RESULT(info); + + // Optional split-kv prefill requires workspace for partial (m, l, acc). + // IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve + // a huge workspace unless the user explicitly enables split-kv. 
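// Illustrative arithmetic (not taken from the patch; the sizes are made up
// but realistic) behind the warning above. Both decode and prefill split-kv
// need num_splits * rows * (head_size + 2) floats, but decode uses
// rows = num_seqs * num_heads while prefill uses rows = total_q_tokens * num_heads:
//
//   decode : 8 * (64 seqs   * 32 heads) * (64  + 2) * 4 B ~=    4.1 MiB
//   prefill: 8 * (8192 toks * 32 heads) * (128 + 2) * 4 B ~= 1040   MiB
//
// so an unconditional prefill reservation could silently claim about 1 GiB.
inline size_t splitKvWorkspaceBytesSketch(size_t num_splits, size_t rows, size_t head_size) {
    return num_splits * rows * (head_size + 2) * sizeof(float);
}
// splitKvWorkspaceBytesSketch(8, 64 * 32, 64)    ==     4,325,376 bytes
// splitKvWorkspaceBytesSketch(8, 8192 * 32, 128) == 1,090,519,040 bytes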
+ bool use_splitkv = false; + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { + use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 1; + if (use_splitkv) { + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) { + const int v = std::atoi(env); + if (v > 0) { + num_splits = v; + } + } else { + num_splits = 4; + } + constexpr int kMaxSplits = 8; + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + } + const size_t n = info->total_q_tokens * info->num_heads; + const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0; + + const size_t workspace_bytes = splitkv_workspace_bytes; + // const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes; + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_bytes, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + const void *block_tables, + const void *total_kv_lens, + const void *cum_seqlens_q, + const void *alibi_slopes, + void *stream_) const { + auto stream = static_cast(stream_); + + const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); + const auto *total_kv_lens_i64 = static_cast(total_kv_lens); + const auto *cu_seqlens_q_i64 = static_cast(cum_seqlens_q); + + bool use_splitkv = false; + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { + use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 1; + if (use_splitkv) { + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) { + const int v = std::atoi(env); + if (v > 0) { + num_splits = v; + } + } else { + // Conservative default; users can override. + num_splits = 4; + } + constexpr int kMaxSplits = 8; + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + const size_t n = _info.total_q_tokens * _info.num_heads; + const size_t required = static_cast(num_splits) * n * (_info.head_size + 2) * sizeof(float); + if (workspace_size < required) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + } + + if (use_splitkv) { + const size_t n = _info.total_q_tokens * _info.num_heads; + float *partial_acc = static_cast(workspace); + float *partial_m = partial_acc + static_cast(num_splits) * n * _info.head_size; + float *partial_l = partial_m + static_cast(num_splits) * n; + + // Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64. 
+#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \ + return launch_prefill_warpcta8pipe_splitkv( \ + partial_acc, partial_m, partial_l, num_splits, \ + static_cast(out), \ + static_cast(q), \ + static_cast(k_cache), \ + static_cast(v_cache), \ + static_cast(BT_PTR), \ + total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream) + + if (_info.dtype == INFINI_DTYPE_F16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_SPLITKV(int64_t, half, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_SPLITKV(int32_t, half, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_SPLITKV(uint32_t, half, block_tables); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (_info.dtype == INFINI_DTYPE_BF16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + +#undef DISPATCH_SPLITKV + } + +// Default to the fastest validated kernel for supported shapes. +// "ref" is still available for debugging/correctness bisecting. +#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \ + do { \ + const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \ + const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \ + if (k_env != nullptr) { \ + const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \ + if (!known) { \ + const char *fallback = default_prefill_kernel(_info); \ + std::fprintf(stderr, \ + "[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \ + k, fallback); \ + k = fallback; \ + } \ + } \ + const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \ + static bool printed_dispatch = false; \ + if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \ + std::fprintf(stderr, \ + "[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \ + k, \ + (k_env == nullptr ? 
"auto" : "env"), \ + static_cast(_info.head_size), \ + static_cast(_info.page_block_size), \ + static_cast(_info.dtype)); \ + printed_dispatch = true; \ + } \ + if (std::strcmp(k, "warp") == 0) { \ + return launch_prefill_warp( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta") == 0) { \ + return launch_prefill( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta8") == 0) { \ + return launch_prefill_warpcta8( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta8pipe") == 0) { \ + return launch_prefill_warpcta8pipe( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if constexpr (std::is_same_v) { \ + if (std::strcmp(k, "warpcta8mma") == 0) { \ + return launch_prefill_warpcta8mma( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, 
_info.o_head_stride, stream); \ + } \ + } \ + if (std::strcmp(k, "warpcta8n128") == 0) { \ + return launch_prefill_warpcta8n128( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta16") == 0) { \ + return launch_prefill_warpcta16( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "ref") == 0) { \ + return launch_prefill_ref( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + return INFINI_STATUS_BAD_PARAM; \ + } while (false) + +#define DISPATCH_INDEX(Tindex) \ + do { \ + if (_info.dtype == INFINI_DTYPE_F16) { \ + DISPATCH_KERNEL(Tindex, half, float); \ + } \ + if (_info.dtype == INFINI_DTYPE_BF16) { \ + DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \ + } \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } while (false) + + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_INDEX(int64_t); + } else if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_INDEX(int32_t); + } else if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_INDEX(uint32_t); + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::paged_attention_prefill::nvidia diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc index e205acca1..af21df651 100644 --- a/src/infiniop/ops/paged_attention_prefill/operator.cc +++ b/src/infiniop/ops/paged_attention_prefill/operator.cc @@ -5,6 +5,9 @@ #ifdef ENABLE_NVIDIA_API #include "nvidia/paged_attention_prefill_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/paged_attention_prefill_metax.h" +#endif __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( infiniopHandle_t handle, @@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( switch (handle->device) { #ifdef ENABLE_NVIDIA_API CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, 
metax) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( switch (desc->device_type) { #ifdef ENABLE_NVIDIA_API GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill( switch (desc->device_type) { #ifdef ENABLE_NVIDIA_API CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( switch (desc->device_type) { #ifdef ENABLE_NVIDIA_API DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/paged_caching/metax/paged_caching_metax.h b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.h new file mode 100644 index 000000000..7ac3fda2c --- /dev/null +++ b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.h @@ -0,0 +1,8 @@ +#ifndef __PAGED_CACHING_METAX_H__ +#define __PAGED_CACHING_METAX_H__ + +#include "../paged_caching.h" + +DESCRIPTOR(metax) + +#endif // __PAGED_CACHING_METAX_H__ diff --git a/src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca new file mode 100644 index 000000000..db761992f --- /dev/null +++ b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca @@ -0,0 +1,157 @@ +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" +#include "../cuda/kernel.cuh" +#include "paged_caching_metax.h" + +template +INFINIOP_METAX_KERNEL pagedCaching( + Tdata *k_cache, Tdata *v_cache, + const Tdata *k, const Tdata *v, + const int64_t *slot_mapping, + const size_t head_size, const size_t block_size, + const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, + const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { + op::paged_caching::cuda::pagedCachingKernel( + k_cache, v_cache, k, v, slot_mapping, head_size, + block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); +} + +namespace op::paged_caching::metax { +// PIMPL struct definition +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +// Destructor implementation +Descriptor::~Descriptor() { + delete _opaque; +} + +// Static factory method implementation +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t slot_mapping_desc) { + + auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc); + CHECK_RESULT(info); + + // Create and return the Descriptor instance. + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +// The launchKernel function is a templated helper to encapsulate the kernel launch. +// It sets up grid/block dimensions and calls the device-side kernel. 
+template +infiniStatus_t launchKernel(const PagedCachingInfo &info, + void *k_cache, void *v_cache, + infiniDtype_t dtype, + const void *k, const void *v, + const void *slot_mapping, + size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, + ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, + ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, + hcStream_t stream) { + + // Grid dimension is 1D, with one block per token, as we decided. + dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1); + // Block dimension is 1D, using the number of threads specified at compile time. + dim3 block(NUM_THREADS); + + // This kernel does not require dynamic shared memory. + size_t shared_mem_size = 0; + + // Launch the device-side kernel. + if (dtype == INFINI_DTYPE_F16) { + pagedCaching + <<>>( + (half *)k_cache, + (half *)v_cache, + (const half *)k, + (const half *)v, + (const int64_t *)slot_mapping, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride); + } else if (dtype == INFINI_DTYPE_BF16) { + pagedCaching + <<>>( + (cuda_bfloat16 *)k_cache, + (cuda_bfloat16 *)v_cache, + (const cuda_bfloat16 *)k, + (const cuda_bfloat16 *)v, + (const int64_t *)slot_mapping, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride); + } else if (dtype == INFINI_DTYPE_F32) { + pagedCaching + <<>>( + (float *)k_cache, + (float *)v_cache, + (const float *)k, + (const float *)v, + (const int64_t *)slot_mapping, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + +// Execution method implementation +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *k_cache, void *v_cache, + const void *k, const void *v, + const void *slot_mapping, + void *stream_) const { + + hcStream_t stream = (hcStream_t)stream_; + + // Dispatch logic based on the device's maximum threads per block. + // This allows selecting the largest, most efficient block size the hardware supports. + int max_threads = _opaque->internal->maxThreadsPerBlock(); + if (max_threads >= METAX_BLOCK_SIZE_1024) { + // Dispatch based on data type for a 1024-thread block. + launchKernel( + _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping, + _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, + _info.k_src_stride, _info.v_src_stride, + _info.k_cache_block_stride, _info.v_cache_block_stride, + stream); + } else if (max_threads >= METAX_BLOCK_SIZE_512) { + launchKernel( + _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping, + _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, + _info.k_src_stride, _info.v_src_stride, + _info.k_cache_block_stride, _info.v_cache_block_stride, + stream); + } else { + // If the device supports fewer threads, return an error. 
+ return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::paged_caching::metax diff --git a/src/infiniop/ops/paged_caching/operator.cc b/src/infiniop/ops/paged_caching/operator.cc index 3bfd92280..6eb746f9f 100644 --- a/src/infiniop/ops/paged_caching/operator.cc +++ b/src/infiniop/ops/paged_caching/operator.cc @@ -5,9 +5,9 @@ #ifdef ENABLE_NVIDIA_API #include "nvidia/paged_caching_nvidia.cuh" #endif -// #ifdef ENABLE_METAX_API -// #include "metax/paged_caching_metax.h" -// #endif +#ifdef ENABLE_METAX_API +#include "metax/paged_caching_metax.h" +#endif __C infiniStatus_t infiniopCreatePagedCachingDescriptor( infiniopHandle_t handle, @@ -29,9 +29,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor( #ifdef ENABLE_NVIDIA_API CREATE(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // CREATE(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -50,9 +50,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( #ifdef ENABLE_NVIDIA_API GET(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // GET(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -75,9 +75,9 @@ __C infiniStatus_t infiniopPagedCaching( #ifdef ENABLE_NVIDIA_API CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // CALCULATE(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -95,9 +95,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor( #ifdef ENABLE_NVIDIA_API DESTROY(INFINI_DEVICE_NVIDIA, nvidia) #endif - // #ifdef ENABLE_METAX_API - // DESTROY(INFINI_DEVICE_METAX, metax) - // #endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax) +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } From 1fa56298c0bdee22fe5e38a80cde064ac67588bf Mon Sep 17 00:00:00 2001 From: wooway777 Date: Mon, 26 Jan 2026 20:40:08 +0800 Subject: [PATCH 23/25] demo131 - patch lua flags and includes --- .../ops/flash_attention/ninetoothed/build.py | 12 +++++------- xmake/hygon.lua | 2 +- xmake/iluvatar.lua | 2 +- xmake/metax.lua | 2 +- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py index 23f265e2e..00f96b9c9 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/build.py +++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py @@ -6,16 +6,14 @@ import torch +import os -def build(): - if torch.cuda.is_available(): - device_count = torch.cuda.device_count() - for i in range(device_count): - device_name = torch.cuda.get_device_name(i).lower() +def build(): - if "metax" in device_name: - return + env_vars_to_check = ["MACA_HOME", "MACA_PATH", "MACA_ROOT"] + if any(var in os.environ for var in env_vars_to_check): + return with_kv_cache_values = (0,) emb_dim_values = (16, 32, 64, 128, 256) diff --git a/xmake/hygon.lua b/xmake/hygon.lua index 05d3e8356..c29126172 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -77,7 +77,7 @@ target("infiniop-hygon") add_files("../src/infiniop/ops/swiglu/nvidia/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) 
+ add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", {cxxflags = {"-Wno-return-type"}}) end target_end() diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index cd9304127..1bb5f6c4c 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -58,7 +58,7 @@ target("infiniop-iluvatar") add_files("../src/infiniop/ops/dequantize_awq/iluvatar/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", {cxxflags = {"-Wno-return-type"}}) end target_end() diff --git a/xmake/metax.lua b/xmake/metax.lua index 65e5d549b..432de74f7 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -54,7 +54,7 @@ target("infiniop-metax") if has_config("ninetoothed") then add_includedirs(MACA_ROOT .. "/include/mcr") - add_files("../build/ninetoothed/*.c", { + add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", { cxflags = { "-include stdlib.h", "-Wno-return-type", From 807e5e436e7a266bbce414b10e377f4b786600b3 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Tue, 27 Jan 2026 03:23:04 +0000 Subject: [PATCH 24/25] issue/811 use relax graph capture mode, add compile flag for graph instantiate --- src/infinicore/graph/graph.cc | 4 +++- src/infinicore/tensor/view.cc | 6 ++++-- xmake.lua | 12 ++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/infinicore/graph/graph.cc b/src/infinicore/graph/graph.cc index 8a06e5f40..3b8fc57e5 100644 --- a/src/infinicore/graph/graph.cc +++ b/src/infinicore/graph/graph.cc @@ -84,7 +84,7 @@ void Graph::instantiate() { if (infinirtStreamBeginCapture( context::getStream(), - INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) + INFINIRT_STREAM_CAPTURE_MODE_RELAXED) != INFINI_STATUS_SUCCESS) { return; } @@ -144,7 +144,9 @@ std::shared_ptr GraphManager::stop_recording() { return nullptr; } recording_ = false; +#ifdef USE_INFINIRT_GRAPH graph_->instantiate(); +#endif return std::exchange(graph_, nullptr); } diff --git a/src/infinicore/tensor/view.cc b/src/infinicore/tensor/view.cc index 21c4fc5cf..051ee42c0 100644 --- a/src/infinicore/tensor/view.cc +++ b/src/infinicore/tensor/view.cc @@ -2,6 +2,8 @@ #include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" +#include "../utils.hpp" + #include #include @@ -62,11 +64,11 @@ Tensor TensorImpl::narrow(const std::vector &slices) const { Tensor TensorImpl::permute(const Shape &order) const { // Validate input - assert(meta_.shape.size() == order.size()); + INFINICORE_ASSERT(meta_.shape.size() == order.size()); // Check that order contains all indices from 0 to n-1 exactly once for (size_t i = 0; i < order.size(); i++) { - assert(std::find(order.begin(), order.end(), i) != order.end()); + INFINICORE_ASSERT(std::find(order.begin(), order.end(), i) != order.end()); } // Permute shape and strides diff --git a/xmake.lua b/xmake.lua index a8e767723..a4d311a7d 100644 --- a/xmake.lua +++ b/xmake.lua @@ -205,6 +205,18 @@ if has_config("ninetoothed") then add_defines("ENABLE_NINETOOTHED") end +-- cuda graph +option("graph") + set_default(false) + set_showmenu(true) + set_description("Whether to use device graph instantiating feature, such as cuda graph for nvidia") +option_end() + +if has_config("graph") then + add_defines("USE_INFINIRT_GRAPH") +end + + -- InfiniCCL option("ccl") set_default(false) From bf0c825dfb10ccb4e506267cf9f2d4ad706caebe Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:06:10 +0800 Subject: [PATCH 
25/25] issue/995 fix paged attn on iluvatar --- .../ops/paged_attention/cuda/kernel_v2.cuh | 4 ++++ src/infiniop/ops/paged_attention/operator.cc | 14 +++++++++++++- .../ops/paged_attention_prefill/operator.cc | 14 +++++++++++++- src/infiniop/ops/paged_caching/operator.cc | 14 +++++++++++++- xmake/iluvatar.lua | 6 +++--- 5 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh index e63dd68e2..2b603217b 100644 --- a/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh +++ b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh @@ -30,7 +30,11 @@ __device__ __forceinline__ float warpReduceMax(float x) { } __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) { +#if defined(ENABLE_ILUVATAR_API) + return static_cast(reinterpret_cast(ptr)); +#else return static_cast(__cvta_generic_to_shared(ptr)); +#endif } __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) { diff --git a/src/infiniop/ops/paged_attention/operator.cc b/src/infiniop/ops/paged_attention/operator.cc index 46bea9e1e..8bb603cdb 100644 --- a/src/infiniop/ops/paged_attention/operator.cc +++ b/src/infiniop/ops/paged_attention/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/paged_attention.h" -#ifdef ENABLE_NVIDIA_API +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) #include "nvidia/paged_attention_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor( #endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -57,6 +60,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( #endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -82,6 +88,9 @@ __C infiniStatus_t infiniopPagedAttention( #endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -102,6 +111,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( #endif #ifdef ENABLE_METAX_API DESTROY(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc index af21df651..207157b22 100644 --- a/src/infiniop/ops/paged_attention_prefill/operator.cc +++ b/src/infiniop/ops/paged_attention_prefill/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/paged_attention_prefill.h" -#ifdef ENABLE_NVIDIA_API +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) #include "nvidia/paged_attention_prefill_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -38,6 +38,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( #endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -59,6 +62,9 @@ __C infiniStatus_t 
infiniopGetPagedAttentionPrefillWorkspaceSize( #endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -87,6 +93,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill( #endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -107,6 +116,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( #endif #ifdef ENABLE_METAX_API DESTROY(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/paged_caching/operator.cc b/src/infiniop/ops/paged_caching/operator.cc index 6eb746f9f..3afc7a84b 100644 --- a/src/infiniop/ops/paged_caching/operator.cc +++ b/src/infiniop/ops/paged_caching/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/paged_caching.h" -#ifdef ENABLE_NVIDIA_API +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) #include "nvidia/paged_caching_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -31,6 +31,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor( #endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -52,6 +55,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( #endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -77,6 +83,9 @@ __C infiniStatus_t infiniopPagedCaching( #endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -97,6 +106,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor( #endif #ifdef ENABLE_METAX_API DESTROY(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 1bb5f6c4c..4c641d459 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -42,14 +42,14 @@ target("infiniop-iluvatar") add_links("cudart", "cublas", "cudnn") set_warnings("all", "error") - add_cuflags("-Wno-error=unused-private-field") + add_cuflags("-Wno-error=unused-private-field", "-Wno-error=unused-variable", "-Wno-unused-variable") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) if has_config("ivcore-20") then add_cuflags("--cuda-gpu-arch=ivcore20", {force = true}) end add_culdflags("-fPIC") - add_cxflags("-fPIC") - add_cxxflags("-fPIC") + add_cxflags("-fPIC", "-Wno-error=unused-variable", "-Wno-unused-variable") + add_cxxflags("-fPIC", "-Wno-error=unused-variable", "-Wno-unused-variable") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")
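
Note on the runtime knobs introduced by the paged-attention prefill dispatch patch above: the split-KV path is opt-in through the INFINIOP_FLASH_PREFILL_SPLITKV / INFINIOP_FLASH_PREFILL_NUM_SPLITS environment variables (default 4 splits, clamped to 8), a specific kernel variant can be forced with INFINIOP_FLASH_PREFILL_KERNEL (warp, warpcta, warpcta8, warpcta8pipe, warpcta8mma, warpcta8n128, warpcta16, ref; unknown values fall back with a warning), and INFINIOP_DEBUG_PREFILL_DISPATCH=1 prints the chosen variant once. The host-side sketch below is illustrative only and not part of the patches: descriptor creation and the actual infiniopPagedAttentionPrefill() call are omitted and the shape numbers are made up; only the environment-variable names and the workspace-sizing formula (num_splits * total_q_tokens * num_heads * (head_size + 2) floats, covering partial_acc, partial_m and partial_l) are taken from the diff.

// Hypothetical caller-side sketch (POSIX setenv assumed); tensor/descriptor setup
// and the infiniopPagedAttentionPrefill() invocation itself are left out.
#include <cstdio>
#include <cstdlib>

int main() {
    setenv("INFINIOP_FLASH_PREFILL_SPLITKV", "1", 1);    // opt in to the split-KV prefill path
    setenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS", "4", 1); // default is 4, clamped to 8 internally
    setenv("INFINIOP_DEBUG_PREFILL_DISPATCH", "1", 1);   // optional: log the dispatch decision once
    // When split-KV is NOT enabled, a specific kernel variant can be forced instead:
    // setenv("INFINIOP_FLASH_PREFILL_KERNEL", "warpcta8pipe", 1);

    // Example shapes (illustrative values, not from the patch).
    const size_t num_splits = 4, total_q_tokens = 4096, num_heads = 32, head_size = 128;
    const size_t n = total_q_tokens * num_heads;
    // Same formula as Descriptor::create(): per split, n*head_size floats for partial_acc
    // plus n floats each for partial_m and partial_l.
    const size_t workspace_bytes = num_splits * n * (head_size + 2) * sizeof(float);
    std::printf("split-KV prefill workspace: %zu bytes\n", workspace_bytes);
    return 0;
}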