Changes from all commits (26 commits)
1e63710  issue/987 - add .cpp files to ninetoothed includes  (wooway777, Jan 27, 2026)
822a534  issue/978 - metax cuda graph impl and wrappings  (wooway777, Jan 23, 2026)
cc2cc3a  issue/846 - Refactor embedding to support device-side input and CUDA …  (gongchensu, Dec 26, 2025)
835209e  issue/900 - support embedding on iluvatar, metax, and moore  (wooway777, Jan 8, 2026)
eb34d4d  issue/900 - adapt to graph and adjust test script  (wooway777, Jan 9, 2026)
f9761a2  issue/900 - maintains classic embedding for devices yet to be worked on  (wooway777, Jan 19, 2026)
0c204df  issue/791 fix add_rmsnorm api and rmsnorm module  (PanZezhong1725, Jan 23, 2026)
dfafc21  issue/884 - add_rms_norm on iluvatar, metax and moore  (wooway777, Jan 7, 2026)
4ddc664  issue/632 - adapt to iluvatar core 20  (wooway777, Jan 19, 2026)
0611cb1  issue/791 - fix add_rmsnorm api on mtx and mth  (wooway777, Jan 26, 2026)
81e5fe9  issue/810 support more ops as graph op  (PanZezhong1725, Jan 19, 2026)
7c5aa16  issue/985 - adjust cxflags and cxxflags for lua scripts  (wooway777, Jan 26, 2026)
55cd22e  issue/402 - convenient ninetoothed util  (voltjia, Aug 25, 2025)
32340fc  issue/925 - Speed up `scripts/build_ntops.py` and `src/infiniop/ninet…  (voltjia, Jan 14, 2026)
ca58118  issue/940 - check build result and implicitly require build.py for bu…  (wooway777, Jan 26, 2026)
47843aa  issue/935 - add metax include dir for ninetoothed  (wooway777, Jan 15, 2026)
6ac8f90  issue/919 - ninetoothed flash attention  (wooway777, Jan 26, 2026)
5614e1b  issue/931 - ninetoothed swiglu for nv, il, mtx  (wooway777, Jan 26, 2026)
97eced0  issue/923 - ninetoothed kv caching for nv, il, mtx  (wooway777, Jan 26, 2026)
1c18c04  issue/979 optimize paged attention  (PanZezhong1725, Jan 23, 2026)
4cd1f68  issue/979 - removed commented paged attn codes  (wooway777, Jan 26, 2026)
7a18d24  issue/983 - adapted the optimized paged attention to metax  (wooway777, Jan 26, 2026)
1fa5629  demo131 - patch lua flags and includes  (wooway777, Jan 26, 2026)
807e5e4  issue/811 use relax graph capture mode, add compile flag for graph in…  (PanZezhong1725, Jan 27, 2026)
70862bc  Merge pull request #989 from InfiniTensor/issue/811-fix  (PanZezhong1725, Jan 27, 2026)
bf0c825  issue/995 fix paged attn on iluvatar  (zhangyue207, Jan 29, 2026)
include/infinicore/graph/graph.hpp (23 changes: 14 additions & 9 deletions)

@@ -15,10 +15,15 @@ class GraphTensor : public Tensor {
 };
 
 class GraphOperator {
 public:
+    virtual void run() const = 0;
+    virtual ~GraphOperator() = default;
+};
+
+class DispatchableGraphOperator : public GraphOperator {
+public:
-    void run() const;
-    ~GraphOperator();
+    void run() const override;
+    ~DispatchableGraphOperator() override;
 
 protected:
     using run_schema = void (*)(void *);

@@ -49,7 +54,7 @@ class Graph {
 } // namespace infinicore::graph
 
 #define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...)               \
-    class __OP_NAME__ : public graph::GraphOperator {             \
+    class __OP_NAME__ : public graph::DispatchableGraphOperator { \
     public:                                                       \
         using schema = void (*)(__VA_ARGS__);                     \
         using plan_schema = void *(*)(__VA_ARGS__);               \

@@ -79,12 +84,12 @@
         runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
         deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
 
-#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
-    auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__);   \
-    if (context::isGraphRecording()) {                      \
-        context::addGraphOperator(op);                      \
-    } else {                                                \
-        op->run();                                          \
+#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...)  \
+    auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
+    if (context::isGraphRecording()) {                       \
+        context::addGraphOperator(___op);                    \
+    } else {                                                 \
+        ___op->run();                                        \
     }
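Splitting GraphOperator into a pure interface plus DispatchableGraphOperator lets an op implement run() directly (as AllReduce does below) while macro-generated ops keep the plan/run/cleanup dispatcher plumbing. Renaming the macro's local from op to ___op also keeps the expansion from shadowing an op identifier in scope at the call site. A minimal sketch of how an operator entry point would use these macros; the Relu operator and its relu_ wrapper are illustrative assumptions, not declarations from this PR:

    #include "infinicore/graph/graph.hpp"

    namespace infinicore::op {

    // Hypothetical op generated by the macro; after this PR it derives
    // from graph::DispatchableGraphOperator rather than GraphOperator.
    INFINICORE_GRAPH_OP_CLASS(Relu, Tensor, const Tensor &);

    void relu_(Tensor out, const Tensor &in) {
        // Expands to: auto ___op = std::make_shared<Relu>(out, in);
        // then context::addGraphOperator(___op) while a graph is being
        // recorded, or ___op->run() for immediate execution.
        INFINICORE_GRAPH_OP_RECORD_OR_RUN(Relu, out, in);
    }

    } // namespace infinicore::op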
include/infinicore/nn/rmsnorm.hpp (23 changes: 19 additions & 4 deletions)

@@ -1,7 +1,7 @@
 #pragma once
 
-#include "module.hpp"
+#include "../ops.hpp"
+#include "module.hpp"
 
 namespace infinicore::nn {
 
@@ -57,6 +57,21 @@ class RMSNorm : public Module {
      */
     Tensor forward(const Tensor &x) const;
 
+    /**
+     * @brief Forward pass: apply RMSNorm in-place with residual
+     *
+     * @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
+     *          Will be modified in-place to the normalized output.
+     * @param residual Residual tensor to add to input before normalization.
+     *                 Will be modified in-place to the sum of input and residual.
+     *
+     * The normalization is applied over the last dimension.
+     * For example:
+     *   Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
+     *   Input: [batch, hidden_size]          -> normalize over hidden_size
+     */
+    void forward_inplace(Tensor &x, Tensor &residual) const;
 
     // Module information
     size_t normalized_shape() const { return normalized_shape_; }
     double eps() const { return eps_; }

@@ -73,9 +88,9 @@ class RMSNorm : public Module {
     INFINICORE_NN_PARAMETER(weight);
 
 private:
-    size_t normalized_shape_; // Size of the feature dimension
-    double eps_; // Epsilon for numerical stability
-    DataType dtype_; // Data type for weight
+    size_t normalized_shape_; // Size of the feature dimension
+    double eps_;              // Epsilon for numerical stability
+    DataType dtype_;          // Data type for weight
 };
 
 } // namespace infinicore::nn
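A minimal usage sketch for the new in-place overload; tensor construction is elided because this diff shows no factory API, and the infinicore::Tensor spelling is assumed from the surrounding headers:

    #include "infinicore/nn/rmsnorm.hpp"

    // Typical decoder-layer step: both arguments are overwritten, so the
    // fused path allocates no intermediates.
    void norm_step(const infinicore::nn::RMSNorm &norm,
                   infinicore::Tensor &x, infinicore::Tensor &residual) {
        norm.forward_inplace(x, residual);
        // Now residual holds x_in + residual_in, and x holds its RMSNorm,
        // normalized over the last dimension.
    }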
include/infinicore/ops.hpp (3 changes: 3 additions & 0 deletions)

@@ -4,6 +4,9 @@
 #include "ops/add_rms_norm.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/embedding.hpp"
+#include "ops/flash_attention.hpp"
+#include "ops/kv_caching.hpp"
 #include "ops/matmul.hpp"
 #include "ops/ones.hpp"
 #include "ops/paged_attention.hpp"
include/infinicore/ops/add.hpp (15 changes: 6 additions & 9 deletions)

@@ -1,17 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class Add {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
-    static void execute(Tensor c, Tensor a, Tensor b);
-    static common::OpDispatcher<schema> &dispatcher();
-};
 
-Tensor add(Tensor a, Tensor b);
-void add_(Tensor c, Tensor a, Tensor b);
-Tensor operator+(Tensor a, Tensor b);
+INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &);
+
+Tensor add(const Tensor &a, const Tensor &b);
+void add_(Tensor c, const Tensor &a, const Tensor &b);
 
 } // namespace infinicore::op
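Call sites keep their shape after the refactor; only the parameter constness changes, though the old Tensor operator+(Tensor, Tensor) overload does not reappear among the new declarations. A sketch, assuming the in-place form is routed through the record-or-run macro like the other graph ops:

    #include "infinicore/ops/add.hpp"

    using infinicore::Tensor;

    Tensor sum(const Tensor &a, const Tensor &b) {
        return infinicore::op::add(a, b); // allocating form
    }

    void sum_into(Tensor c, const Tensor &a, const Tensor &b) {
        // In-place form; while a graph is being captured this records
        // the op into the graph instead of executing it.
        infinicore::op::add_(c, a, b);
    }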
include/infinicore/ops/add_rms_norm.hpp (14 changes: 6 additions & 8 deletions)

@@ -5,16 +5,14 @@
 #include <utility>
 
 namespace infinicore::op {
-class AddRMSNorm {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
-    static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float);
 
 // Fused Add and RMS Normalization
 // Returns: (normalized_result, add_result)
 // The add_result can be used as residual for subsequent layers
-std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
-void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
+std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
+void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
+// Fused Add and RMS Normalization (in-place)
+// normalized_result will be stored in input, add_result will be stored in residual
+void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f);
 } // namespace infinicore::op
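A sketch exercising the three declared entry points with the default epsilon; tensor shapes are assumed compatible:

    #include "infinicore/ops/add_rms_norm.hpp"

    using infinicore::Tensor;

    void fused_norm_demo(Tensor out, Tensor residual_out,
                         Tensor x, Tensor residual, const Tensor &w) {
        // Allocating: returns (normalized_result, add_result).
        auto [normed, added] = infinicore::op::add_rms_norm(x, residual, w);
        // Preallocated outputs.
        infinicore::op::add_rms_norm_(out, residual_out, x, residual, w);
        // In-place: x becomes the normalized result, residual the sum.
        infinicore::op::add_rms_norm_inplace(x, residual, w);
    }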
include/infinicore/ops/causal_softmax.hpp (14 changes: 6 additions & 8 deletions)

@@ -1,16 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class CausalSoftmax {
-public:
-    using schema = void (*)(Tensor, Tensor);
-    static void execute(Tensor output, Tensor input);
-    static common::OpDispatcher<schema> &dispatcher();
-};
 
-Tensor causal_softmax(Tensor input);
-void causal_softmax_(Tensor output, Tensor input);
+INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &);
+
+Tensor causal_softmax(const Tensor &input);
+void causal_softmax_(Tensor output, const Tensor &input);
 
 } // namespace infinicore::op
include/infinicore/ops/distributed/allreduce.hpp (24 changes: 24 additions & 0 deletions)

@@ -0,0 +1,24 @@
+#pragma once
+
+#include "../../device.hpp"
+#include "../../graph/graph.hpp"
+#include "../common/op.hpp"
+
+#include <infiniccl.h>
+
+namespace infinicore::op::distributed {
+class AllReduce : public graph::GraphOperator {
+public:
+    AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
+    ~AllReduce();
+    void run() const override;
+    static void execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
+
+private:
+    void *planned_meta_;
+};
+
+Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
+void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
+
+} // namespace infinicore::op::distributed
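AllReduce implements graph::GraphOperator directly instead of going through the dispatch macro, and appears to plan its metadata once in the constructor (planned_meta_), presumably so graph replay skips planning. A usage sketch that stays agnostic about infiniccl's reduce-op constants, which this diff does not show:

    #include "infinicore/ops/distributed/allreduce.hpp"

    using infinicore::Tensor;

    // Reduce `local` across all ranks in `comm`; the reduce op (e.g. a
    // sum) is passed through unchanged, so no enum values are assumed.
    Tensor reduce_across_ranks(const Tensor &local,
                               infinicclReduceOp_t op,
                               infinicclComm_t comm) {
        return infinicore::op::distributed::allreduce(local, op, comm);
    }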
include/infinicore/ops/embedding.hpp (8 changes: 6 additions & 2 deletions)

@@ -1,9 +1,13 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
 
-Tensor embedding(Tensor input, Tensor weight);
-void embedding_(Tensor out, Tensor input, Tensor weight);
+INFINICORE_GRAPH_OP_CLASS(Embedding, Tensor, const Tensor &, const Tensor &);
+
+Tensor embedding(const Tensor &input, const Tensor &weight);
+void embedding_(Tensor out, const Tensor &input, const Tensor &weight);
 } // namespace infinicore::op
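Per commit issue/846 above, embedding now accepts device-side input, which is what allows it to be recorded as a graph op. A minimal sketch:

    #include "infinicore/ops/embedding.hpp"

    using infinicore::Tensor;

    // Gather rows of `weight` selected by `token_ids`; with device-side
    // indices the lookup can be captured into a graph.
    Tensor embed(const Tensor &token_ids, const Tensor &weight) {
        return infinicore::op::embedding(token_ids, weight);
    }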
include/infinicore/ops/flash_attention.hpp (12 changes: 12 additions & 0 deletions)

@@ -0,0 +1,12 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);
+
+Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
+void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
+} // namespace infinicore::op
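A sketch of the new entry point; the argument roles are inferred from parameter names only, so treat the shapes and the meaning of total_kv_len as assumptions:

    #include "infinicore/ops/flash_attention.hpp"

    using infinicore::Tensor;

    // Causal flash attention; `scale` is conventionally 1/sqrt(head_dim).
    Tensor attend(const Tensor &q, const Tensor &k, const Tensor &v,
                  const Tensor &total_kv_len, float scale) {
        return infinicore::op::flash_attention(q, k, v, total_kv_len,
                                               scale, /*is_causal=*/true);
    }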
include/infinicore/ops/gemm.hpp (6 changes: 3 additions & 3 deletions)

@@ -6,9 +6,9 @@
 
 namespace infinicore::op {
 
-INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
+INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float);
 
-Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
-void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
+Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f);
+void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta);
 
 } // namespace infinicore::op
include/infinicore/ops/kv_caching.hpp (16 changes: 16 additions & 0 deletions)

@@ -0,0 +1,16 @@
+#pragma once
+
+#include "../device.hpp"
+#include "../graph/graph.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);
+
+void kv_caching_(Tensor k_cache,
+                 Tensor v_cache,
+                 const Tensor &k,
+                 const Tensor &v,
+                 const Tensor &past_kv_lengths);
+} // namespace infinicore::op
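Only an in-place entry point is declared, which fits an op whose whole job is mutating preallocated caches. A sketch (the offset semantics of past_kv_lengths are inferred from the name):

    #include "infinicore/ops/kv_caching.hpp"

    using infinicore::Tensor;

    // Append this step's K/V into contiguous caches at the per-sequence
    // positions given by past_kv_lengths.
    void append_kv(Tensor k_cache, Tensor v_cache,
                   const Tensor &k, const Tensor &v,
                   const Tensor &past_kv_lengths) {
        infinicore::op::kv_caching_(k_cache, v_cache, k, v, past_kv_lengths);
    }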
include/infinicore/ops/mul.hpp (14 changes: 6 additions & 8 deletions)

@@ -1,16 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class Mul {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
-    static void execute(Tensor c, Tensor a, Tensor b);
-    static common::OpDispatcher<schema> &dispatcher();
-};
 
-Tensor mul(Tensor a, Tensor b);
-void mul_(Tensor c, Tensor a, Tensor b);
+INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &);
+
+Tensor mul(const Tensor &a, const Tensor &b);
+void mul_(Tensor c, const Tensor &a, const Tensor &b);
 
 } // namespace infinicore::op
include/infinicore/ops/paged_attention.hpp (18 changes: 10 additions & 8 deletions)

@@ -1,18 +1,20 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 #include <optional>
 
 namespace infinicore::op {
 
-class PagedAttention {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
-    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>, float);
 
+Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
+                       const Tensor &block_tables, const Tensor &kv_lens,
+                       std::optional<Tensor> alibi_slopes, float scale);
+
+void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
+                      const Tensor &block_tables, const Tensor &kv_lens,
+                      std::optional<Tensor> alibi_slopes, float scale);
+
-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
 } // namespace infinicore::op
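Besides the const-ref migration, the new declarations rename cache_lens to kv_lens. A decode-time sketch; passing std::nullopt to skip ALiBi is assumed valid since the parameter is a std::optional:

    #include "infinicore/ops/paged_attention.hpp"

    using infinicore::Tensor;

    // Attention over block-paged KV caches: block_tables maps each
    // sequence to its cache blocks, kv_lens gives the valid lengths.
    Tensor decode_attend(const Tensor &q,
                         const Tensor &k_cache, const Tensor &v_cache,
                         const Tensor &block_tables, const Tensor &kv_lens,
                         float scale) {
        return infinicore::op::paged_attention(q, k_cache, v_cache,
                                               block_tables, kv_lens,
                                               std::nullopt, scale);
    }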
include/infinicore/ops/paged_caching.hpp (10 changes: 3 additions & 7 deletions)

@@ -1,17 +1,13 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
 
-class PagedCaching {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
-    static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);
 
-void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
+void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping);
 
 } // namespace infinicore::op
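paged_caching_ is the paged counterpart of kv_caching_ above: rather than appending at per-sequence offsets, it scatters K/V into cache blocks through an explicit slot_mapping. A sketch (slot semantics inferred from the parameter name):

    #include "infinicore/ops/paged_caching.hpp"

    using infinicore::Tensor;

    // Scatter this step's K/V into the paged caches; slot_mapping gives
    // each new token's destination slot.
    void scatter_kv(Tensor k_cache, Tensor v_cache,
                    const Tensor &k, const Tensor &v,
                    const Tensor &slot_mapping) {
        infinicore::op::paged_caching_(k_cache, v_cache, k, v, slot_mapping);
    }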
include/infinicore/ops/rearrange.hpp (14 changes: 6 additions & 8 deletions)

@@ -1,16 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class Rearrange {
-public:
-    using schema = void (*)(Tensor, Tensor);
-    static void execute(Tensor y, Tensor x);
-    static common::OpDispatcher<schema> &dispatcher();
-};
 
-Tensor rearrange(Tensor x);
-void rearrange_(Tensor y, Tensor x);
+INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &);
+
+Tensor rearrange(const Tensor &x);
+void rearrange_(Tensor y, const Tensor &x);
 
 } // namespace infinicore::op
include/infinicore/ops/rms_norm.hpp (14 changes: 6 additions & 8 deletions)

@@ -1,16 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class RMSNorm {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, float);
-    static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
-    static common::OpDispatcher<schema> &dispatcher();
-};
 
-Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
-void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
+INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float);
+
+Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
+void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
 
 } // namespace infinicore::op