10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
softmax,
sub,
tanh,
maximum,
atan,
batch_norm,
bincount,
adaptive_max_pool2d,
)

__all__ = [
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"maximum",
"atan",
"batch_norm",
"bincount",
"adaptive_max_pool2d",
]
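
For orientation, here is a hedged sketch of how one of these newly exported kernels might be materialized and called. It assumes `ninetoothed.make(arrangement, application, tensors)` composes the `premake()` triple into a launchable kernel, as shown in the ninetoothed README; the real ntops entry points may wrap this differently, and the `ndim`/`dim` arguments here are illustrative.

```python
# Hypothetical usage sketch (not part of this PR): turning a premake() triple
# into a callable kernel. Assumes ninetoothed.make(arrangement, application,
# tensors) as in the ninetoothed README; signatures may differ in ntops.
import ninetoothed
import torch

from ntops.kernels import atan

arrangement, application, tensors = atan.premake(ndim=2, dim=1)
kernel = ninetoothed.make(arrangement, application, tensors)

x = torch.randn(4, 1024, device="cuda")
y = torch.empty_like(x)
kernel(x, y)  # expected: y ≈ torch.atan(x)
```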
76 changes: 76 additions & 0 deletions src/ntops/kernels/adaptive_max_pool2d.py
@@ -0,0 +1,76 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

def arrangement(input, output, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size):
    if block_size is None:
        block_size = ninetoothed.block_size()

    # input: (N, C, H_in, W_in)
    # output: (N, C, H_out, W_out)

    # Tile the input into pooling windows. floor_mode=True matches the default
    # behavior; for adaptive pooling, the precomputed stride/kernel pair is
    # expected to cover the input exactly.
    input_arranged = input.tile(
        (1, 1, kernel_size_h, kernel_size_w),
        (1, 1, stride_h, stride_w),
    )
    # => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)

    input_arranged = input_arranged.ravel()
    # => (N, C, H_out, W_out, 1, 1, k_h, k_w)

    input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
    # => (N*C*H_out*W_out, k_h*k_w)

    # Round the window size up to the next power of two for the parallel reduction.
    nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length()
    input_arranged = input_arranged.tile((1, nearest_pow2))
    # => (..., 1), dtype=(1, nearest_pow2)

    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    # => (..., 1), dtype=(nearest_pow2,)

    input_arranged = input_arranged.tile((block_size, -1))
    # => (..., 1), dtype=(block_size, nearest_pow2)
    input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)

    # Arrange the output layout to match the input's block_size.
    output_arranged = output.tile((1, 1, 1, 1))
    output_arranged = output_arranged.ravel()
    output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
    output_arranged = output_arranged.tile((block_size, -1))
    output_arranged.dtype = output_arranged.dtype.squeeze(1)

    return input_arranged, output_arranged

def application(input, output):
    # input: (block_size, nearest_pow2)
    # output: (block_size,)

    # Plain max reduction. premake() sets other=float("-inf") on the input,
    # so both out-of-bounds loads and the power-of-two padding read as
    # negative infinity and cannot affect the max.
    output = ntl.max(input, axis=1)

def premake(ndim, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size=None, dtype=None):
    arrangement_ = functools.partial(
        arrangement,
        kernel_size_h=kernel_size_h,
        kernel_size_w=kernel_size_w,
        stride_h=stride_h,
        stride_w=stride_w,
        block_size=block_size,
    )

    tensors = (
        # Input: other=float("-inf") so out-of-bounds tile padding never wins the max.
        Tensor(ndim, dtype=dtype, other=float("-inf")),
        Tensor(ndim, dtype=dtype),  # output
    )

    return arrangement_, application, tensors
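
Since `arrangement` consumes a single uniform `(kernel, stride)` pair per axis, a brief host-side sketch of where those values can come from may help. The helper below is hypothetical (not part of this PR) and only covers the case where the input size is divisible by the output size, which is exactly when adaptive max pooling coincides with ordinary max pooling:

```python
# Hypothetical host-side helper: derive the uniform kernel/stride pair fed to
# arrangement(). With one (kernel, stride) per axis, the kernel matches
# PyTorch's adaptive pooling exactly only when in_size % out_size == 0, in
# which case adaptive_max_pool2d(x, S) == max_pool2d(x, N // S, N // S).
import torch
import torch.nn.functional as F

def uniform_pool_params(in_size: int, out_size: int) -> tuple[int, int]:
    assert in_size % out_size == 0, "uniform windows need divisible sizes"
    stride = in_size // out_size
    return stride, stride  # kernel_size == stride covers the input exactly

x = torch.randn(2, 3, 32, 32)
k_h, s_h = uniform_pool_params(32, 8)
k_w, s_w = uniform_pool_params(32, 8)

expected = F.adaptive_max_pool2d(x, (8, 8))
reference = F.max_pool2d(x, kernel_size=(k_h, k_w), stride=(s_h, s_w))
assert torch.equal(expected, reference)
```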
148 changes: 148 additions & 0 deletions src/ntops/kernels/atan.py
@@ -0,0 +1,148 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def _atan_taylor_poly(x, dtype):
    """
    Evaluate a Taylor-series approximation of atan(x).
    Valid range: |x| <= 0.42.
    Within this range, a degree-15 polynomial is accurate to roughly
    float32 level.
    """
    x2 = x * x

    # Taylor coefficients: 1, -1/3, 1/5, -1/7, ...
    # Keep enough decimal places for accuracy.
    c3 = -0.333333333333
    c5 = 0.2
    c7 = -0.142857142857
    c9 = 0.111111111111
    c11 = -0.090909090909
    c13 = 0.076923076923
    c15 = -0.066666666667

    # Horner's rule: x * (1 + x^2 * (c3 + x^2 * (...)))
    p = ntl.cast(c15, dtype)
    p = p * x2 + ntl.cast(c13, dtype)
    p = p * x2 + ntl.cast(c11, dtype)
    p = p * x2 + ntl.cast(c9, dtype)
    p = p * x2 + ntl.cast(c7, dtype)
    p = p * x2 + ntl.cast(c5, dtype)
    p = p * x2 + ntl.cast(c3, dtype)

    # result = x + x^3 * p = x * (1 + x^2 * p)
    # Factoring out x saves a multiplication and improves stability for small values.
    return x + x * x2 * p


def _atan(x, dtype):
    """
    Numerically stable arctangent.
    A two-level range reduction maps the input into a small interval where
    the polynomial approximation is accurate.
    """
    calc_dtype = dtype if dtype != ntl.float16 else ntl.float32

    # === Constants (defined locally to avoid scoping issues) ===
    PI_OVER_2 = 1.5707963267948966
    PI_OVER_4 = 0.7853981633974483
    TAN_PI_8 = 0.4142135623730950  # tan(pi/8)

    x_arg = ntl.cast(x, calc_dtype)

    # 0. Extract the sign and take the absolute value.
    # atan(-x) = -atan(x)
    sign = ntl.where(x_arg < 0.0, -1.0, 1.0)
    abs_x = ntl.abs(x_arg)

    # 1. First reduction: handle |x| > 1.
    # Identity: atan(x) = pi/2 - atan(1/x) for x > 0.
    # If x > 1:
    #     val_1 = 1/x
    #     offset_1 = pi/2
    #     coef_1 = -1
    # Otherwise:
    #     val_1 = x
    #     offset_1 = 0
    #     coef_1 = 1
    # Running expression: offset_1 + coef_1 * atan(val_1)
    mask_gt_1 = abs_x > 1.0

    # Safe division: avoid dividing by zero when abs_x == 0
    # (that branch is never selected, but both sides are evaluated).
    safe_abs_x = ntl.where(mask_gt_1, abs_x, ntl.cast(1.0, calc_dtype))

    val_1 = ntl.where(mask_gt_1, ntl.cast(1.0, calc_dtype) / safe_abs_x, abs_x)
    offset_1 = ntl.where(mask_gt_1, PI_OVER_2, 0.0)
    coef_1 = ntl.where(mask_gt_1, -1.0, 1.0)

    # val_1 is now in [0, 1].

    # 2. Second reduction: handle x close to 1.
    # Identity: atan(x) = pi/4 + atan((x - 1) / (x + 1))
    # With the threshold at tan(pi/8) ≈ 0.414, the reduced value lies in
    # [-0.29, 0.414], which is very favorable for Taylor-series convergence.
    mask_gt_tan_pi_8 = val_1 > TAN_PI_8

    # Compute (x - 1) / (x + 1).
    # val_1 is non-negative, so the denominator val_1 + 1 >= 1: no division by zero.
    reduced_val = (val_1 - 1.0) / (val_1 + 1.0)

    val_2 = ntl.where(mask_gt_tan_pi_8, reduced_val, val_1)
    offset_2 = ntl.where(mask_gt_tan_pi_8, PI_OVER_4, 0.0)

    # Running expression: atan(val_1) = offset_2 + atan(val_2)
    # |val_2| <= 0.4142... at this point.

    # 3. Polynomial evaluation.
    poly_res = _atan_taylor_poly(val_2, calc_dtype)

    # 4. Combine the results.
    # result = sign * (offset_1 + coef_1 * (offset_2 + poly_res))

    # Inner part: atan(val_1)
    atan_val_1 = ntl.cast(offset_2, calc_dtype) + poly_res

    # Outer part: the unsigned result
    abs_result = ntl.cast(offset_1, calc_dtype) + ntl.cast(coef_1, calc_dtype) * atan_val_1

    # Restore the sign.
    final_result = ntl.cast(sign, calc_dtype) * abs_result

    return ntl.cast(final_result, dtype)


def application(input, output):
    """
    Compute the arctangent atan(x).

    Args:
        input: input tensor of shape (C // block_size, block_size)
        output: output tensor of shape (C // block_size, block_size)
    """
    dtype = output.dtype.dtype

    for i in range(input.shape[0]):
        # Load the current block.
        input_block = ntl.cast(input[i], dtype)

        # Compute atan(x).
        result = _atan(input_block, dtype)

        # Store the result.
        output[i] = result


def premake(ndim, dim, dtype=None, block_size=None):
    """
    Prepare the atan kernel.
    """
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),  # input tensor
        Tensor(ndim, dtype=dtype),  # output tensor
    )

    return arrangement_, application, tensors
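
As a sanity check on the reduction scheme, the same two-level range reduction and degree-15 Horner polynomial can be restated in plain NumPy and compared against `np.arctan`. This is a verification sketch, not the kernel itself:

```python
# Plain-NumPy restatement of the two-level range reduction above, checked
# against np.arctan. Illustrative only; the kernel runs in the ntl DSL.
import numpy as np

def atan_reduced(x):
    sign = np.where(x < 0.0, -1.0, 1.0)
    ax = np.abs(x)

    # Level 1: atan(x) = pi/2 - atan(1/x) for x > 1 (safe-divide as in the kernel).
    gt1 = ax > 1.0
    safe = np.where(gt1, ax, 1.0)
    v1 = np.where(gt1, 1.0 / safe, ax)
    off1 = np.where(gt1, np.pi / 2, 0.0)
    coef1 = np.where(gt1, -1.0, 1.0)

    # Level 2: atan(x) = pi/4 + atan((x - 1) / (x + 1)) above tan(pi/8).
    gt_t = v1 > np.tan(np.pi / 8)
    v2 = np.where(gt_t, (v1 - 1.0) / (v1 + 1.0), v1)
    off2 = np.where(gt_t, np.pi / 4, 0.0)

    # Degree-15 Taylor polynomial via Horner's rule; |v2| <= tan(pi/8).
    c = [-1 / 3, 1 / 5, -1 / 7, 1 / 9, -1 / 11, 1 / 13, -1 / 15]
    x2 = v2 * v2
    p = c[-1]
    for ck in reversed(c[:-1]):
        p = p * x2 + ck
    poly = v2 + v2 * x2 * p

    return sign * (off1 + coef1 * (off2 + poly))

x = np.linspace(-10, 10, 10001)
assert np.max(np.abs(atan_reduced(x) - np.arctan(x))) < 1e-6
```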
52 changes: 52 additions & 0 deletions src/ntops/kernels/batch_norm.py
@@ -0,0 +1,52 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, weight, bias, eps, output, num_normalized_elements):
    # Compute the variance via E[x^2] - E[x]^2 to avoid an explicit padding
    # mask: padded input elements are 0, and 0 squared is also 0, so padding
    # pollutes neither sum nor sum_sq.

    _sum = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
    _sum_sq = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

    # Pass 1: accumulate the sum and the sum of squares.
    for i in range(input.shape[0]):
        val = ntl.cast(input[i], ntl.float32)
        _sum += val
        _sum_sq += val * val

    mean = ntl.sum(_sum, 0) / num_normalized_elements
    mean_sq = ntl.sum(_sum_sq, 0) / num_normalized_elements

    # Var = E[x^2] - (E[x])^2
    var = mean_sq - mean * mean
    # Clamp the variance to be non-negative (guards against rounding error).
    var = ntl.maximum(var, 0.0)

    std = ntl.sqrt(var + eps)

    # Pass 2: normalize and write the output.
    # The store is typically masked by the compiler for out-of-bounds parts,
    # based on the tensor shapes.
    for i in range(input.shape[0]):
        output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight[i] + bias[i]


def premake(ndim, reduction_dims, num_elements, dtype=None, block_size=None):
    # reduction_dims specifies the dimensions to reduce over.
    arrangement_ = functools.partial(arrangement, dim=reduction_dims, block_size=block_size)

    tensors = (
        Tensor(ndim, other=0, dtype=dtype),  # input (other=0 so padding reads as 0)
        Tensor(ndim, dtype=dtype),  # weight
        Tensor(ndim, dtype=dtype),  # bias
        Tensor(0, dtype=dtype),  # eps
        Tensor(ndim, dtype=dtype),  # output
        Tensor(0, dtype=dtype, constexpr=True, value=num_elements),
    )

    return arrangement_, application, tensors
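
The padding argument and the `E[x^2] - E[x]^2` identity are easy to verify off-device. A NumPy sketch (an illustration, not the kernel; the on-device kernel accumulates in float32, while float64 is used here so the comparisons are exact to rounding):

```python
# Sanity check of the one-pass variance identity Var[x] = E[x^2] - E[x]^2
# and of the zero-padding argument used in the kernel above.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4096))

mean = x.mean(axis=1)
mean_sq = (x * x).mean(axis=1)
var_one_pass = np.maximum(mean_sq - mean * mean, 0.0)  # clamp like the kernel

var_two_pass = ((x - mean[:, None]) ** 2).mean(axis=1)
assert np.allclose(var_one_pass, var_two_pass)

# Zero padding contributes 0 to both sums, so dividing by the true element
# count (not the padded length) reproduces the same statistics.
padded = np.concatenate([x, np.zeros((8, 128))], axis=1)
n = x.shape[1]
assert np.allclose(padded.sum(axis=1) / n, mean)
assert np.allclose((padded * padded).sum(axis=1) / n, mean_sq)
```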
86 changes: 86 additions & 0 deletions src/ntops/kernels/bincount.py
@@ -0,0 +1,86 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

def arrangement(input, weights, output, bin_ids, T, S, T_pow2, S_pow2, block_size=None):
    # input: (T,)
    # weights: (T,)
    # output: (S,) - the result we write to
    # bin_ids: (S,) - auxiliary indices, one per output position

    # 1. Tile the output into parallel blocks (the grid dimension).
    # output_tiled: (GridSize, block_size)
    output_tiled = output.tile((block_size,))

    # bin_ids is tiled alongside output so the kernel knows which bins
    # the current block is responsible for.
    bin_ids_tiled = bin_ids.tile((block_size,))

    # 2. input and weights must be visible to every block.
    # Expand them to GridSize, then reshape to satisfy the T_pow2 layout.
    grid_size = output_tiled.shape[0]

    # input: (T,) -> (GridSize, T) -> (GridSize, 1, T_pow2) (via tile on dim 1)
    # Note: tile((1, T_pow2)) lets the ninetoothed framework handle the inner
    # dimension correctly.
    input_expand = input.unsqueeze(0).expand((grid_size, -1))
    input_tiled = input_expand.tile((1, T_pow2)).squeeze(1)

    weights_expand = weights.unsqueeze(0).expand((grid_size, -1))
    weights_tiled = weights_expand.tile((1, T_pow2)).squeeze(1)

    return input_tiled, weights_tiled, output_tiled, bin_ids_tiled, T, S

def application(input, weights, output, bin_ids, T, S):
    # input: (1, T_pow2) <- broadcast from arrangement
    # weights: (1, T_pow2)
    # output: (block_size,) <- the output slice this block owns
    # bin_ids: (block_size,) <- the bin indices this block owns

    # 1. Shape the operands for a broadcast comparison.
    # We build the matrix Match[i, j] = (bin_ids[i] == input[j]):
    # bin_ids: (block_size, 1)
    # input: (1, T_pow2)

    bin_ids_col = ntl.expand_dims(bin_ids, 1)  # (block_size, 1)

    # Broadcast input and weights along dim 0 to match block_size.
    input_b = ntl.broadcast_to(input, (bin_ids.shape[0], input.shape[1]))  # (block_size, T_pow2)
    weights_b = ntl.broadcast_to(weights, (bin_ids.shape[0], weights.shape[1]))  # (block_size, T_pow2)

    # 2. Build the validity mask (handle padding).
    # The real input length is T; everything beyond it up to T_pow2 is padding.
    col_indices = ntl.arange(0, input.shape[1])  # (T_pow2,)
    col_valid = col_indices < T  # (T_pow2,)

    # Broadcast the mask.
    col_valid_b = ntl.expand_dims(col_valid, 0)
    col_valid_b = ntl.broadcast_to(col_valid_b, (bin_ids.shape[0], input.shape[1]))

    # 3. Core computation: masking + sum.
    # Find which input values fall into the bins owned by this block.
    match_mask = (input_b == bin_ids_col)

    # Combine with the validity check.
    final_mask = ntl.where(col_valid_b, match_mask, False)

    # Select the weights (with all-ones weights this reduces to counting).
    selected = ntl.where(final_mask, weights_b, 0.0)

    # Sum along the T dimension to get the total per bin.
    result = ntl.sum(selected, axis=1)  # (block_size,)

    # 4. Write back.
    output = result

def premake(T_pow2, S_pow2, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, T_pow2=T_pow2, S_pow2=S_pow2, block_size=block_size)

    tensors = (
        Tensor(1, dtype=int, shape_options={'constexpr': True}),  # input
        Tensor(1, dtype=dtype, shape_options={'constexpr': True}),  # weights
        Tensor(1, dtype=dtype, shape_options={'constexpr': True}),  # output
        Tensor(1, dtype=int, shape_options={'constexpr': True}),  # bin_ids
        Tensor(0, dtype=int, constexpr=True),  # T
        Tensor(0, dtype=int, constexpr=True),  # S
    )

    return arrangement_, application, tensors
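
The comparison-matrix formulation is likewise easy to check against `np.bincount`. A NumPy restatement of the idea (illustrative only; sizes and the RNG are arbitrary):

```python
# Restatement of the trick above: Match[i, j] = (bin_ids[i] == input[j]),
# followed by a masked, weighted row sum. Checked against np.bincount.
import numpy as np

rng = np.random.default_rng(0)
S = 16                                    # number of bins
input_ids = rng.integers(0, S, size=100)  # T = 100
weights = rng.random(100)

bin_ids = np.arange(S)                          # one row per output bin
match = bin_ids[:, None] == input_ids[None, :]  # (S, T) comparison matrix
result = np.where(match, weights[None, :], 0.0).sum(axis=1)

assert np.allclose(result, np.bincount(input_ids, weights=weights, minlength=S))
```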