From 1ce235af7f4a75a4b915103f75ab12af67c5513c Mon Sep 17 00:00:00 2001
From: clovercx <76864803+clovercx@users.noreply.github.com>
Date: Wed, 28 Jan 2026 15:48:58 +0800
Subject: [PATCH 1/5] Add files via upload

---
 src/ntops/kernels/adaptive_max_pool2d.py |  76 ++++++++++++
 src/ntops/kernels/atan.py                | 148 +++++++++++++++++++++++
 src/ntops/kernels/batch_norm.py          |  52 ++++++++
 src/ntops/kernels/bincount.py            |  86 +++++++++++++
 src/ntops/kernels/maximum.py             |  34 ++++++
 5 files changed, 396 insertions(+)
 create mode 100644 src/ntops/kernels/adaptive_max_pool2d.py
 create mode 100644 src/ntops/kernels/atan.py
 create mode 100644 src/ntops/kernels/batch_norm.py
 create mode 100644 src/ntops/kernels/bincount.py
 create mode 100644 src/ntops/kernels/maximum.py

diff --git a/src/ntops/kernels/adaptive_max_pool2d.py b/src/ntops/kernels/adaptive_max_pool2d.py
new file mode 100644
index 0000000..5caaac1
--- /dev/null
+++ b/src/ntops/kernels/adaptive_max_pool2d.py
@@ -0,0 +1,76 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(input, output, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # input: (N, C, H_in, W_in)
+    # output: (N, C, H_out, W_out)
+
+    # Tile the input into pooling windows. For adaptive pooling, the
+    # precomputed stride/kernel pair is what guarantees full coverage.
+    input_arranged = input.tile(
+        (1, 1, kernel_size_h, kernel_size_w),
+        (1, 1, stride_h, stride_w),
+    )
+    # => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)
+
+    input_arranged = input_arranged.ravel()
+    # => (N, C, H_out, W_out, 1, 1, k_h, k_w)
+
+    input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
+    # => (N*C*H_out*W_out, k_h*k_w)
+
+    # Pad the window axis to the next power of two for the parallel reduction.
+    nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length()
+    input_arranged = input_arranged.tile((1, nearest_pow2))
+    # => (..., 1), dtype=(1, nearest_pow2)
+
+    input_arranged.dtype = input_arranged.dtype.squeeze(0)
+    # => (..., 1), dtype=(nearest_pow2,)
+
+    input_arranged = input_arranged.tile((block_size, -1))
+    # => (..., 1), dtype=(block_size, nearest_pow2)
+    input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
+
+    # Arrange the output layout to match the input's block_size.
+    output_arranged = output.tile((1, 1, 1, 1))
+    output_arranged = output_arranged.ravel()
+    output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
+    output_arranged = output_arranged.tile((block_size, -1))
+    output_arranged.dtype = output_arranged.dtype.squeeze(1)
+
+    return input_arranged, output_arranged
+
+
+def application(input, output):
+    # input: (block_size, nearest_pow2)
+    # output: (block_size,)
+
+    # A plain max reduction suffices: premake sets other=float("-inf"), so
+    # both out-of-bounds reads and the power-of-two padding positions are
+    # loaded as -inf and can never win the max.
+    output = ntl.max(input, axis=1)
+
+
+def premake(ndim, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size=None, dtype=None):
+    arrangement_ = functools.partial(
+        arrangement,
+        kernel_size_h=kernel_size_h,
+        kernel_size_w=kernel_size_w,
+        stride_h=stride_h,
+        stride_w=stride_w,
+        block_size=block_size,
+    )
+
+    tensors = (
+        # input: other is -inf so that padded reads cannot affect the max
+        Tensor(ndim, dtype=dtype, other=float("-inf")),
+        Tensor(ndim, dtype=dtype),  # output
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
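
A minimal pure-Python sketch of the power-of-two padding used in `arrangement` above (illustrative only; `window` is hypothetical data, not DSL code). Padding the k_h * k_w window out to the next power of two with -inf can never change the max:

    window = [3.0, 1.0, 4.0, 1.0, 5.0]                  # k_h * k_w = 5 values
    nearest_pow2 = 1 << (len(window) - 1).bit_length()  # -> 8
    padded = window + [float("-inf")] * (nearest_pow2 - len(window))
    assert max(padded) == max(window)                   # -inf never wins
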
diff --git a/src/ntops/kernels/atan.py b/src/ntops/kernels/atan.py
new file mode 100644
index 0000000..00d84e2
--- /dev/null
+++ b/src/ntops/kernels/atan.py
@@ -0,0 +1,148 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.reduction import arrangement
+
+
+def _atan_taylor_poly(x, dtype):
+    """
+    Compute a Taylor-series approximation of atan(x).
+
+    Valid range: |x| <= 0.42. Within this range the degree-15 polynomial is
+    accurate to roughly 2e-8 at the interval edge, i.e. full float32
+    precision (but not float64 precision).
+    """
+    x2 = x * x
+
+    # Taylor coefficients: 1, -1/3, 1/5, -1/7, ...
+    # Enough digits are kept so the rounded constants do not limit accuracy.
+    c3 = -0.333333333333
+    c5 = 0.2
+    c7 = -0.142857142857
+    c9 = 0.111111111111
+    c11 = -0.090909090909
+    c13 = 0.076923076923
+    c15 = -0.066666666667
+
+    # Horner's rule: x * (1 + x^2 * (c3 + x^2 * (...)))
+    p = ntl.cast(c15, dtype)
+    p = p * x2 + ntl.cast(c13, dtype)
+    p = p * x2 + ntl.cast(c11, dtype)
+    p = p * x2 + ntl.cast(c9, dtype)
+    p = p * x2 + ntl.cast(c7, dtype)
+    p = p * x2 + ntl.cast(c5, dtype)
+    p = p * x2 + ntl.cast(c3, dtype)
+
+    # result = x + x^3 * p = x * (1 + x^2 * p)
+    # Factoring out x saves a multiply and improves stability for small inputs.
+    return x + x * x2 * p
+
+
+def _atan(x, dtype):
+    """
+    Numerically stable, high-accuracy arctangent.
+
+    A two-level range reduction maps the input into a small interval so the
+    polynomial stays within its accurate range.
+    """
+    calc_dtype = dtype if dtype != ntl.float16 else ntl.float32
+
+    # === Constants (defined locally to avoid scoping issues) ===
+    PI_OVER_2 = 1.5707963267948966
+    PI_OVER_4 = 0.7853981633974483
+    TAN_PI_8 = 0.4142135623730950  # tan(pi/8)
+
+    x_arg = ntl.cast(x, calc_dtype)
+
+    # 0. Extract the sign and take the absolute value: atan(-x) = -atan(x).
+    sign = ntl.where(x_arg < 0.0, -1.0, 1.0)
+    abs_x = ntl.abs(x_arg)
+
+    # 1. First reduction: handle |x| > 1 via atan(x) = pi/2 - atan(1/x), x > 0.
+    #    If x > 1: val_1 = 1/x, offset_1 = pi/2, coef_1 = -1
+    #    Else:     val_1 = x,   offset_1 = 0,    coef_1 = 1
+    #    Running expression: offset_1 + coef_1 * atan(val_1)
+    mask_gt_1 = abs_x > 1.0
+
+    # Safe division: avoid dividing by zero when abs_x == 0 (the quotient is
+    # discarded by the select in that case, but the division still executes).
+    safe_abs_x = ntl.where(mask_gt_1, abs_x, ntl.cast(1.0, calc_dtype))
+
+    val_1 = ntl.where(mask_gt_1, ntl.cast(1.0, calc_dtype) / safe_abs_x, abs_x)
+    offset_1 = ntl.where(mask_gt_1, PI_OVER_2, 0.0)
+    coef_1 = ntl.where(mask_gt_1, -1.0, 1.0)
+
+    # val_1 now lies in [0, 1].
+
+    # 2. Second reduction: handle values near 1 via
+    #    atan(x) = pi/4 + atan((x - 1) / (x + 1)).
+    #    With the threshold at tan(pi/8) ≈ 0.414, the reduced value lies in
+    #    [-0.29, 0.414], which is ideal for the Taylor series.
+    mask_gt_tan_pi_8 = val_1 > TAN_PI_8
+
+    # Compute (x - 1) / (x + 1). val_1 is non-negative, so the denominator
+    # val_1 + 1 is always >= 1 and there is no division-by-zero risk.
+    reduced_val = (val_1 - 1.0) / (val_1 + 1.0)
+
+    val_2 = ntl.where(mask_gt_tan_pi_8, reduced_val, val_1)
+    offset_2 = ntl.where(mask_gt_tan_pi_8, PI_OVER_4, 0.0)
+
+    # Running expression: atan(val_1) = offset_2 + atan(val_2), with
+    # |val_2| <= 0.4142...
+
+    # 3. Polynomial evaluation.
+    poly_res = _atan_taylor_poly(val_2, calc_dtype)
+
+    # 4. Recombine: result = sign * (offset_1 + coef_1 * (offset_2 + poly_res)).
+
+    # Inner term: atan(val_1)
+    atan_val_1 = ntl.cast(offset_2, calc_dtype) + poly_res
+
+    # Outer term: |result|
+    abs_result = ntl.cast(offset_1, calc_dtype) + ntl.cast(coef_1, calc_dtype) * atan_val_1
+
+    # Restore the sign.
+    final_result = ntl.cast(sign, calc_dtype) * abs_result
+
+    return ntl.cast(final_result, dtype)
+
+
+def application(input, output):
+    """
+    Compute the arctangent atan(x).
+
+    Arguments:
+        input: input tensor of shape (C // block_size, block_size)
+        output: output tensor of shape (C // block_size, block_size)
+    """
+    dtype = output.dtype.dtype
+
+    for i in range(input.shape[0]):
+        # Load the current block.
+        input_block = ntl.cast(input[i], dtype)
+
+        # Compute atan(x).
+        result = _atan(input_block, dtype)
+
+        # Store the result.
+        output[i] = result
+
+
+def premake(ndim, dim, dtype=None, block_size=None):
+    """
+    Prepare the atan kernel.
+    """
+    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),  # input tensor
+        Tensor(ndim, dtype=dtype),  # output tensor
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
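
A pure-Python check of the polynomial used in `_atan_taylor_poly` above, evaluated on the reduced range |x| <= tan(pi/8) (a sketch only; `_atan_poly_ref` is a hypothetical helper, not part of the kernel). The tolerance follows from the first omitted series term, x^17/17 ≈ 1.8e-8 at the interval edge:

    import math

    def _atan_poly_ref(x):
        # Same coefficients as _atan_taylor_poly, highest order first (Horner).
        x2 = x * x
        p = 0.0
        for c in (-0.066666666667, 0.076923076923, -0.090909090909,
                  0.111111111111, -0.142857142857, 0.2, -0.333333333333):
            p = p * x2 + c
        return x + x * x2 * p

    for v in (0.0, 0.1, 0.3, 0.4142):
        assert abs(_atan_poly_ref(v) - math.atan(v)) < 1e-7
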
diff --git a/src/ntops/kernels/batch_norm.py b/src/ntops/kernels/batch_norm.py
new file mode 100644
index 0000000..0df582a
--- /dev/null
+++ b/src/ntops/kernels/batch_norm.py
@@ -0,0 +1,52 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.reduction import arrangement
+
+
+def application(input, weight, bias, eps, output, num_normalized_elements):
+    # Compute the variance via E[x^2] - E[x]^2, which avoids an explicit
+    # padding mask: padded positions read as 0, and 0 squared is also 0, so
+    # they contaminate neither sum nor sum_sq.
+
+    _sum = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
+    _sum_sq = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
+
+    # Pass 1: accumulate the sum and the sum of squares.
+    for i in range(input.shape[0]):
+        val = ntl.cast(input[i], ntl.float32)
+        _sum += val
+        _sum_sq += val * val
+
+    mean = ntl.sum(_sum, 0) / num_normalized_elements
+    mean_sq = ntl.sum(_sum_sq, 0) / num_normalized_elements
+
+    # Var = E[x^2] - (E[x])^2
+    var = mean_sq - mean * mean
+    # Clamp to zero to absorb floating-point round-off.
+    var = ntl.maximum(var, 0.0)
+
+    std = ntl.sqrt(var + eps)
+
+    # Pass 2: normalize and write the output. The store is masked by the
+    # compiler according to the tensor shape, so out-of-bounds positions are
+    # not written.
+    for i in range(input.shape[0]):
+        output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight[i] + bias[i]
+
+
+def premake(ndim, reduction_dims, num_elements, dtype=None, block_size=None):
+    # reduction_dims selects the dimensions to reduce over.
+    arrangement_ = functools.partial(arrangement, dim=reduction_dims, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, other=0, dtype=dtype),  # input (other=0 makes padding read as 0)
+        Tensor(ndim, dtype=dtype),  # weight
+        Tensor(ndim, dtype=dtype),  # bias
+        Tensor(0, dtype=dtype),  # eps
+        Tensor(ndim, dtype=dtype),  # output
+        Tensor(0, dtype=dtype, constexpr=True, value=num_elements),
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
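
A quick sanity check of the zero-padding argument above: padded positions contribute 0 to both accumulators, so dividing by the true element count recovers the exact statistics (pure Python, hypothetical values):

    xs = [1.0, 2.0, 3.0, 4.0]
    padded = xs + [0.0] * 4          # simulated block padding
    n = len(xs)                      # num_normalized_elements, not len(padded)
    mean = sum(padded) / n
    var = sum(v * v for v in padded) / n - mean * mean
    assert mean == 2.5 and abs(var - 1.25) < 1e-12
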
diff --git a/src/ntops/kernels/bincount.py b/src/ntops/kernels/bincount.py
new file mode 100644
index 0000000..daff19c
--- /dev/null
+++ b/src/ntops/kernels/bincount.py
@@ -0,0 +1,86 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(input, weights, output, bin_ids, T, S, T_pow2, S_pow2, block_size=None):
+    # input: (T,)
+    # weights: (T,)
+    # output: (S,) - the result we write to
+    # bin_ids: (S,) - auxiliary indices, one per output position
+
+    # 1. Partition the output across the grid.
+    # output_tiled: (GridSize, block_size)
+    output_tiled = output.tile((block_size,))
+
+    # bin_ids is tiled alongside output so the kernel knows which bins the
+    # current block is responsible for.
+    bin_ids_tiled = bin_ids.tile((block_size,))
+
+    # 2. input and weights must be visible to every block: expand them to
+    # GridSize, then reshape to satisfy the T_pow2 requirement.
+    grid_size = output_tiled.shape[0]
+
+    # input: (T,) -> (GridSize, T) -> (GridSize, 1, T_pow2) via tile on dim 1
+    # Note: tile((1, T_pow2)) lets ninetoothed handle the inner dimension
+    # correctly.
+    input_expand = input.unsqueeze(0).expand((grid_size, -1))
+    input_tiled = input_expand.tile((1, T_pow2)).squeeze(1)
+
+    weights_expand = weights.unsqueeze(0).expand((grid_size, -1))
+    weights_tiled = weights_expand.tile((1, T_pow2)).squeeze(1)
+
+    return input_tiled, weights_tiled, output_tiled, bin_ids_tiled, T, S
+
+
+def application(input, weights, output, bin_ids, T, S):
+    # input: (1, T_pow2) <- broadcast from arrangement
+    # weights: (1, T_pow2)
+    # output: (block_size,) <- the output slice this block owns
+    # bin_ids: (block_size,) <- the bin indices this block owns
+
+    # 1. Reshape for a broadcast comparison. We want the matrix
+    #    Match[i, j] = (bin_ids[i] == input[j])
+    #    bin_ids: (block_size, 1)
+    #    input: (1, T_pow2)
+    bin_ids_col = ntl.expand_dims(bin_ids, 1)  # (block_size, 1)
+
+    # Broadcast input and weights along dim 0 to match block_size.
+    input_b = ntl.broadcast_to(input, (bin_ids.shape[0], input.shape[1]))  # (block_size, T_pow2)
+    weights_b = ntl.broadcast_to(weights, (bin_ids.shape[0], weights.shape[1]))  # (block_size, T_pow2)
+
+    # 2. Build the validity mask (handles padding): the real input length is
+    #    T; columns beyond it are padding introduced by T_pow2.
+    col_indices = ntl.arange(0, input.shape[1])  # (T_pow2,)
+    col_valid = col_indices < T  # (T_pow2,)
+
+    # Broadcast the mask.
+    col_valid_b = ntl.expand_dims(col_valid, 0)
+    col_valid_b = ntl.broadcast_to(col_valid_b, (bin_ids.shape[0], input.shape[1]))
+
+    # 3. Core computation: mask + sum. Find the input values that fall into
+    #    the bins owned by this block.
+    match_mask = (input_b == bin_ids_col)
+
+    # Combine with the validity mask.
+    final_mask = ntl.where(col_valid_b, match_mask, False)
+
+    # Select the weights (all-ones weights reduce to a plain count).
+    selected = ntl.where(final_mask, weights_b, 0.0)
+
+    # Sum along the T axis to get the total per bin.
+    result = ntl.sum(selected, axis=1)  # (block_size,)
+
+    # 4. Write back.
+    output = result
+
+
+def premake(T_pow2, S_pow2, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, T_pow2=T_pow2, S_pow2=S_pow2, block_size=block_size)
+
+    tensors = (
+        Tensor(1, dtype=int, shape_options={'constexpr': True}),  # input
+        Tensor(1, dtype=dtype, shape_options={'constexpr': True}),  # weights
+        Tensor(1, dtype=dtype, shape_options={'constexpr': True}),  # output
+        Tensor(1, dtype=int, shape_options={'constexpr': True}),  # bin_ids
+        Tensor(0, dtype=int, constexpr=True),  # T
+        Tensor(0, dtype=int, constexpr=True),  # S
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
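
The masked match-and-sum in `application` above is equivalent to this scalar reference (pure Python, hypothetical values; each block would handle one `bin_ids` slice):

    values  = [1, 1, 2, 0, 1]        # flattened input
    weights = [1, 1, 1, 1, 1]        # all-ones weights == plain counting
    bin_ids = [0, 1, 2]              # bins owned by one block
    counts = [sum(w for v, w in zip(values, weights) if v == b) for b in bin_ids]
    assert counts == [1, 3, 1]
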
diff --git a/src/ntops/kernels/maximum.py b/src/ntops/kernels/maximum.py
new file mode 100644
index 0000000..db85057
--- /dev/null
+++ b/src/ntops/kernels/maximum.py
@@ -0,0 +1,34 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement_elementwise(input, other, output, block_size=None):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # Core strategy for element-wise ops: view the multi-dimensional tensors
+    # as flat 1-D arrays and tile them into blocks.
+    input = input.flatten().tile((block_size,))
+    other = other.flatten().tile((block_size,))
+    output = output.flatten().tile((block_size,))
+
+    return input, other, output
+
+
+def application(input, other, output):
+    # Use the DSL's maximum primitive. Note: the NaN behavior of maximum in
+    # Triton/DSL depends on the underlying implementation.
+    val = ntl.maximum(input, other)
+
+    # Generate indices and write back.
+    indices = ntl.arange(0, input.shape[0])
+    output[indices] = val
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement_elementwise, block_size=block_size)
+    tensors = (
+        Tensor(ndim, dtype=dtype),  # input
+        Tensor(ndim, dtype=dtype),  # other
+        Tensor(ndim, dtype=dtype),  # output
+    )
+    return arrangement_, application, tensors
\ No newline at end of file

From a35b23fb8deb9feeb0cec080b885a4ef411a07a3 Mon Sep 17 00:00:00 2001
From: clovercx <76864803+clovercx@users.noreply.github.com>
Date: Wed, 28 Jan 2026 15:50:48 +0800
Subject: [PATCH 2/5] Add files via upload

---
 src/ntops/torch/adaptive_max_pool2d.py | 48 ++++++++++++++++++
 src/ntops/torch/atan.py                | 68 ++++++++++++++++++++++++
 src/ntops/torch/batch_norm.py          | 44 +++++++++++++++++
 src/ntops/torch/bincount.py            | 53 ++++++++++++++++++++
 src/ntops/torch/maximum.py             | 44 +++++++++++++++++
 5 files changed, 257 insertions(+)
 create mode 100644 src/ntops/torch/adaptive_max_pool2d.py
 create mode 100644 src/ntops/torch/atan.py
 create mode 100644 src/ntops/torch/batch_norm.py
 create mode 100644 src/ntops/torch/bincount.py
 create mode 100644 src/ntops/torch/maximum.py

diff --git a/src/ntops/torch/adaptive_max_pool2d.py b/src/ntops/torch/adaptive_max_pool2d.py
new file mode 100644
index 0000000..79098a7
--- /dev/null
+++ b/src/ntops/torch/adaptive_max_pool2d.py
@@ -0,0 +1,48 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def adaptive_max_pool2d(input, output_size):
+    assert input.ndim in (3, 4), "Input tensor must be 4-dimensional (N, C, H, W) or 3-dimensional (C, H, W)"
+
+    squeeze_batch = input.ndim == 3
+    if squeeze_batch:
+        input = input.unsqueeze(0)
+
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    H_in, W_in = input.shape[-2], input.shape[-1]
+    H_out, W_out = output_size
+
+    # Derive the stride and kernel size:
+    #   stride = input // output
+    #   kernel = input - (output - 1) * stride
+    # Note: a fixed stride/kernel pair assumes reasonably regular input
+    # sizes. PyTorch's fully dynamic adaptive pooling (in particular the
+    # non-divisible case) would require dynamic window support in the DSL;
+    # this is a fixed-window approximation of the general behavior.
+    stride_h = H_in // H_out
+    stride_w = W_in // W_out
+
+    kernel_h = H_in - (H_out - 1) * stride_h
+    kernel_w = W_in - (W_out - 1) * stride_w
+
+    output_shape = (input.shape[0], input.shape[1], H_out, W_out)
+    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+    block_size = 1024
+
+    kernel = _cached_make(
+        ntops.kernels.adaptive_max_pool2d.premake,
+        input.ndim,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        block_size=block_size,
+        dtype=input.dtype,
+    )
+
+    kernel(input, output)
+
+    # Restore the unbatched layout for 3-dimensional inputs.
+    if squeeze_batch:
+        output = output.squeeze(0)
+
+    return output
\ No newline at end of file
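
A worked example of the fixed-window parameters computed above (assumes the formulas exactly as written): for H_in = 10 and H_out = 3,

    H_in, H_out = 10, 3
    stride_h = H_in // H_out                  # 3
    kernel_h = H_in - (H_out - 1) * stride_h  # 10 - 2 * 3 = 4
    # The last window ends exactly at the input edge:
    assert (H_out - 1) * stride_h + kernel_h == H_in
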
diff --git a/src/ntops/torch/atan.py b/src/ntops/torch/atan.py
new file mode 100644
index 0000000..4005fdf
--- /dev/null
+++ b/src/ntops/torch/atan.py
@@ -0,0 +1,68 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def atan(input, *, out=None):
+    """
+    Compute the arctangent atan(x).
+
+    Arguments:
+        input: input tensor
+        out: optional output tensor
+
+    Returns:
+        a tensor containing atan of the input
+
+    Numerical properties:
+        - Domain: (-inf, +inf)
+        - Range: (-pi/2, +pi/2)
+        - float16 is supported (computed internally in float32)
+    """
+    # Determine the output dtype.
+    tensor_dtype = out.dtype if out is not None else input.dtype
+
+    # Create a temporary output tensor.
+    temp_out = torch.empty_like(input, dtype=tensor_dtype, device=input.device)
+
+    # Set the block size.
+    block_size = 256
+
+    # Build (or fetch the cached) atan kernel.
+    kernel = _cached_make(
+        ntops.kernels.atan.premake,
+        input.ndim,  # ndim
+        0,  # dummy dim
+        dtype=input.dtype,
+        block_size=block_size,
+    )
+
+    # Run the kernel.
+    kernel(input, temp_out)
+
+    # Handle the out parameter.
+    if out is not None:
+        if out.shape != temp_out.shape:
+            raise RuntimeError(f"Expected out tensor to have shape {temp_out.shape}, but got {out.shape}")
+        if out.dtype != temp_out.dtype:
+            raise RuntimeError(f"Expected out tensor to have dtype {temp_out.dtype}, but got {out.dtype}")
+
+        out.copy_(temp_out)
+        return out
+
+    return temp_out
+
+
+# Expose atan as a tensor method (note: this overrides the built-in
+# torch.Tensor.atan).
+def _atan_tensor_method(self, *, out=None):
+    return atan(self, out=out)
+
+
+torch.Tensor.atan = _atan_tensor_method
+
+
+# Alias for NumPy-style naming.
+def arctan(input, *, out=None):
+    """Alias of atan."""
+    return atan(input, out=out)
\ No newline at end of file
diff --git a/src/ntops/torch/batch_norm.py b/src/ntops/torch/batch_norm.py
new file mode 100644
index 0000000..3fc1d7b
--- /dev/null
+++ b/src/ntops/torch/batch_norm.py
@@ -0,0 +1,44 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def batch_norm(input, weight=None, bias=None, eps=1e-5, training=True):
+    ndim = input.ndim
+    if ndim < 2:
+        raise ValueError("Input to batch_norm must have at least 2 dimensions")
+
+    # dim 1 is assumed to be the channel dimension; the rest are batch and
+    # spatial dimensions.
+    channel_dim = 1
+    reduction_dims = tuple(d for d in range(ndim) if d != channel_dim)
+
+    num_elements = 1
+    for d in reduction_dims:
+        num_elements *= input.shape[d]
+
+    # Build the broadcasting shape (1, C, 1, 1, ...).
+    C = input.shape[channel_dim]
+    shape_for_broadcast = [1] * ndim
+    shape_for_broadcast[channel_dim] = C
+
+    if weight is None:
+        weight = torch.ones(C, dtype=input.dtype, device=input.device)
+    if bias is None:
+        bias = torch.zeros(C, dtype=input.dtype, device=input.device)
+
+    # Expand to the full input size and let the kernel handle the rest.
+    weight_expanded = weight.view(*shape_for_broadcast).expand_as(input)
+    bias_expanded = bias.view(*shape_for_broadcast).expand_as(input)
+
+    output = torch.empty_like(input)
+
+    kernel = _cached_make(
+        ntops.kernels.batch_norm.premake,
+        ndim,
+        reduction_dims,
+        num_elements,
+    )
+
+    kernel(input, weight_expanded, bias_expanded, eps, output, num_elements)
+
+    return output
\ No newline at end of file
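
A shape-only check of the (1, C, 1, 1) broadcast used above (plain PyTorch, hypothetical sizes):

    import torch

    x = torch.randn(2, 3, 4, 5)                      # (N, C, H, W)
    w = torch.ones(3).view(1, 3, 1, 1).expand_as(x)  # per-channel -> full size
    assert w.shape == x.shape
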
diff --git a/src/ntops/torch/bincount.py b/src/ntops/torch/bincount.py
new file mode 100644
index 0000000..439dc90
--- /dev/null
+++ b/src/ntops/torch/bincount.py
@@ -0,0 +1,53 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def bincount(input, weights=None, minlength=0):
+    if input.ndim != 1:
+        raise ValueError("input must be 1-dimensional")
+    if weights is not None and weights.ndim != 1:
+        raise ValueError("weights must be 1-dimensional")
+    if weights is not None and weights.shape[0] != input.shape[0]:
+        raise ValueError("weights should have the same shape as input")
+
+    T = input.shape[0]
+
+    # Determine the output size S.
+    if T > 0:
+        max_val = input.max().item()
+        S = max(int(max_val) + 1, minlength)
+    else:
+        S = minlength
+
+    # Round T up to the next power of two for arange and tiling.
+    T_pow2 = 1 << (T - 1).bit_length() if T > 0 else 1
+
+    # The kernel logic does not require S to be a power of two; S_pow2 is
+    # passed only for interface consistency.
+    S_pow2 = 1
+
+    # Resolve the weights and output dtypes. Torch semantics: with no
+    # weights the result is long (int64); otherwise it follows the weights
+    # dtype.
+    if weights is None:
+        weights = torch.ones_like(input, dtype=torch.int64)
+        out_dtype = torch.int64
+    else:
+        out_dtype = weights.dtype
+
+    # Prepare the output tensor.
+    output = torch.zeros(S, dtype=out_dtype, device=input.device)
+
+    # Prepare the bin IDs (an auxiliary tensor telling each block which bins
+    # it is responsible for).
+    bin_ids = torch.arange(S, dtype=input.dtype, device=input.device)
+
+    # Block size: how many output bins each grid instance handles.
+    block_size = 128
+
+    kernel = _cached_make(ntops.kernels.bincount.premake,
+                          T_pow2=T_pow2,
+                          S_pow2=S_pow2,
+                          dtype=out_dtype,
+                          block_size=block_size)
+
+    kernel(input, weights, output, bin_ids, T, S)
+
+    return output
\ No newline at end of file
diff --git a/src/ntops/torch/maximum.py b/src/ntops/torch/maximum.py
new file mode 100644
index 0000000..510efc7
--- /dev/null
+++ b/src/ntops/torch/maximum.py
@@ -0,0 +1,44 @@
+import torch
+
+import ntops
+# The kernel implementation from src/ntops/kernels/maximum.py above.
+import ntops.kernels.maximum
+from ntops.torch.utils import _cached_make
+
+
+def maximum(input, other, out=None):
+    # 1. Handle broadcasting so input and other share a shape.
+    input_b, other_b = torch.broadcast_tensors(input, other)
+
+    # 2. Ensure contiguous memory: the DSL kernel assumes densely packed
+    #    data.
+    input_b = input_b.contiguous()
+    other_b = other_b.contiguous()
+
+    output_shape = input_b.shape
+
+    # 3. Prepare the output tensor.
+    if out is None:
+        out = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+    else:
+        # Basic shape check.
+        assert out.shape == output_shape, f"Output shape mismatch: expected {output_shape}, got {out.shape}"
+        if not out.is_contiguous():
+            raise RuntimeError("Output tensor must be contiguous for maximum kernel")
+
+    # 4. Set the block size. 1024 is an empirical default; it could be tuned
+    #    per hardware or via a heuristic.
+    block_size = 1024
+
+    # 5. Build (or fetch the cached) kernel.
+    kernel = _cached_make(
+        ntops.kernels.maximum.premake,
+        input_b.ndim,
+        dtype=input_b.dtype,
+        block_size=block_size,
+    )
+
+    # 6. Run the kernel.
+    kernel(input_b, other_b, out)
+
+    return out
\ No newline at end of file
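
A minimal usage sketch (illustrative only: it presumes a CUDA device and the `ntops.torch` exports added in PATCH 4 below):

    import torch
    import ntops

    a = torch.randn(4, 1, 8, device="cuda")
    b = torch.randn(1, 6, 8, device="cuda")
    out = ntops.torch.maximum(a, b)   # broadcasts to (4, 6, 8)
    assert out.shape == (4, 6, 8)
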

From 35821572ab391b8332aa68ba35fcd4b1bc7ee13 Mon Sep 17 00:00:00 2001
From: clovercx <76864803+clovercx@users.noreply.github.com>
Date: Wed, 28 Jan 2026 15:52:47 +0800
Subject: [PATCH 3/5] Update __init__.py

---
 src/ntops/kernels/__init__.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..3043ce7 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
     softmax,
     sub,
     tanh,
+    maximum,
+    atan,
+    batch_norm,
+    bincount,
+    adaptive_max_pool2d,
 )
 
 __all__ = [
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "maximum",
+    "atan",
+    "batch_norm",
+    "bincount",
+    "adaptive_max_pool2d",
 ]

From 854fbbc4aa35de251b77c5f8e704005560cc551e Mon Sep 17 00:00:00 2001
From: clovercx <76864803+clovercx@users.noreply.github.com>
Date: Wed, 28 Jan 2026 15:53:50 +0800
Subject: [PATCH 4/5] Add new functions to torch module exports

---
 src/ntops/torch/__init__.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..52d06f6 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
 from ntops.torch.softmax import softmax
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
+from ntops.torch.maximum import maximum
+from ntops.torch.atan import atan
+from ntops.torch.batch_norm import batch_norm
+from ntops.torch.bincount import bincount
+from ntops.torch.adaptive_max_pool2d import adaptive_max_pool2d
 
 __all__ = [
     "abs",
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "maximum",
+    "atan",
+    "batch_norm",
+    "bincount",
+    "adaptive_max_pool2d",
 ]

From f1b2600ebe72a9841c7ecd86e59afd8d64cfbbd0 Mon Sep 17 00:00:00 2001
From: clovercx <76864803+clovercx@users.noreply.github.com>
Date: Wed, 28 Jan 2026 15:54:55 +0800
Subject: [PATCH 5/5] Add files via upload

---
 tests/test_adaptive_max_pool2d.py |  35 +++++++++
 tests/test_atan.py                | 120 ++++++++++++++++++++++++++++++
 tests/test_batch_norm.py          |  44 +++++++++++
 tests/test_bincount.py            |  47 ++++++++++++
 tests/test_maximum.py             |  62 +++++++++++++++
 5 files changed, 308 insertions(+)
 create mode 100644 tests/test_adaptive_max_pool2d.py
 create mode 100644 tests/test_atan.py
 create mode 100644 tests/test_batch_norm.py
 create mode 100644 tests/test_bincount.py
 create mode 100644 tests/test_maximum.py

diff --git a/tests/test_adaptive_max_pool2d.py b/tests/test_adaptive_max_pool2d.py
new file mode 100644
index 0000000..434db30
--- /dev/null
+++ b/tests/test_adaptive_max_pool2d.py
@@ -0,0 +1,35 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("output_size", [(1, 1), (2, 2), (3, 3), (5, 7)])
+def test_adaptive_max_pool2d(output_size):
+    device = "cuda"
+    dtype = torch.float32
+
+    batch = random.randint(1, 3)
+    channels = random.randint(1, 4)
+    # Use evenly divisible sizes so the fixed-window kernel matches PyTorch
+    # exactly; if the DSL tile does not support dynamic strides,
+    # non-divisible cases may differ at the boundaries.
+    base_h, base_w = output_size
+    height = base_h * random.randint(1, 5)
+    width = base_w * random.randint(1, 5)
+
+    input_tensor = torch.randn((batch, channels, height, width), device=device, dtype=dtype)
+
+    # ntops implementation
+    ntops_output = ntops.torch.adaptive_max_pool2d(
+        input_tensor,
+        output_size,
+    )
+
+    # Reference implementation
+    reference_output = torch.nn.functional.adaptive_max_pool2d(
+        input_tensor,
+        output_size,
+    )
+
+    assert torch.allclose(ntops_output, reference_output, atol=1e-3, rtol=1e-3)
\ No newline at end of file
diff --git a/tests/test_atan.py b/tests/test_atan.py
new file mode 100644
index 0000000..b20f49b
--- /dev/null
+++ b/tests/test_atan.py
@@ -0,0 +1,120 @@
+import math
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_atan_basic(shape, dtype, device, rtol, atol):
+    """
+    Basic functional test against the native PyTorch implementation.
+    """
+    # atan is defined on all reals; generate random values in [-10, 10].
+    input_tensor = (torch.rand(shape, dtype=dtype, device=device) - 0.5) * 20.0
+
+    reference_output = torch.atan(input_tensor)
+    ntops_output = ntops.torch.atan(input_tensor)
+
+    assert torch.allclose(ntops_output, reference_output, rtol=rtol, atol=atol)
+    assert ntops_output.shape == input_tensor.shape
+    assert ntops_output.dtype == input_tensor.dtype
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
+@pytest.mark.parametrize("device", ["cuda"])
+def test_atan_boundary_values(dtype, device):
+    """
+    Boundary values: 0, 1, -1, inf, -inf.
+    """
+    test_values = torch.tensor([
+        0.0,
+        1.0,
+        -1.0,
+        float('inf'),
+        -float('inf'),
+    ], dtype=dtype, device=device)
+
+    reference_output = torch.atan(test_values)
+    ntops_output = ntops.torch.atan(test_values)
+
+    # Verify numerical accuracy; float16 gets looser tolerances.
+    rtol = 1e-3 if dtype == torch.float16 else 1e-5
+    atol = 1e-3 if dtype == torch.float16 else 1e-6
+
+    assert torch.allclose(ntops_output, reference_output, rtol=rtol, atol=atol)
+
+    # Explicitly verify the special mathematical values:
+    # atan(0) = 0
+    assert torch.abs(ntops_output[0]) < 1e-6
+    # atan(inf) = pi/2
+    assert torch.abs(ntops_output[3] - (math.pi / 2)) < (1e-3 if dtype == torch.float16 else 1e-6)
+    # atan(-inf) = -pi/2
+    assert torch.abs(ntops_output[4] - (-math.pi / 2)) < (1e-3 if dtype == torch.float16 else 1e-6)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_atan_strided_output(dtype):
+    """
+    Non-contiguous (strided) output tensors.
+    """
+    device = "cuda"
+    input_tensor = torch.randn(4, 5, 6, dtype=dtype, device=device)
+    reference_output = torch.atan(input_tensor)
+
+    # Create a non-contiguous output tensor.
+    large_tensor = torch.empty(4, 5, 8, dtype=dtype, device=device)
+    out = large_tensor[:, :, :6]
+
+    result = ntops.torch.atan(input_tensor, out=out)
+
+    rtol = 1e-3 if dtype == torch.float16 else 1e-5
+    atol = 1e-3 if dtype == torch.float16 else 1e-6
+
+    assert torch.allclose(result, reference_output, rtol=rtol, atol=atol)
+    assert result is out
+    assert not out.is_contiguous()
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_atan_inplace_compatibility(dtype):
+    """
+    Compatibility with in-place operations.
+    """
+    device = "cuda"
+    input_tensor = torch.randn(10, 10, dtype=dtype, device=device)
+    input_copy = input_tensor.clone()
+
+    # Chained in-place operations.
+    input_copy.mul_(2.0)
+    input_copy.add_(0.5)
+
+    reference_output = torch.atan(input_copy)
+    ntops_output = ntops.torch.atan(input_copy)
+
+    rtol = 1e-3 if dtype == torch.float16 else 1e-5
+    atol = 1e-3 if dtype == torch.float16 else 1e-6
+
+    assert torch.allclose(ntops_output, reference_output, rtol=rtol, atol=atol)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_atan_nan_handling(dtype):
+    """
+    NaN inputs must propagate to NaN outputs.
+    """
+    device = "cuda"
+    input_tensor = torch.tensor([float('nan'), 1.0], dtype=dtype, device=device)
+
+    output = ntops.torch.atan(input_tensor)
+
+    assert torch.isnan(output[0])
+    assert not torch.isnan(output[1])
\ No newline at end of file
diff --git a/tests/test_batch_norm.py b/tests/test_batch_norm.py
new file mode 100644
index 0000000..7d65564
--- /dev/null
+++ b/tests/test_batch_norm.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("eps", (1e-5, 1e-3))
+@pytest.mark.parametrize("affine", (True, False))
+@pytest.mark.parametrize(*generate_arguments())
+def test_batch_norm(shape, dtype, device, rtol, atol, affine, eps):
+    if len(shape) < 2:
+        return
+
+    input = torch.randn(shape, dtype=dtype, device=device)
+    C = shape[1]
+
+    if affine:
+        weight = torch.randn(C, dtype=dtype, device=device)
+        bias = torch.randn(C, dtype=dtype, device=device)
+    else:
+        weight = None
+        bias = None
+
+    # Run the DSL implementation.
+    ninetoothed_output = ntops.torch.batch_norm(
+        input, weight=weight, bias=bias, eps=eps, training=True
+    )
+
+    # Run the PyTorch reference. training=True forces the statistics to be
+    # computed from the current batch.
+    reference_output = torch.nn.functional.batch_norm(
+        input,
+        running_mean=None,
+        running_var=None,
+        weight=weight,
+        bias=bias,
+        training=True,
+        eps=eps,
+    )
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
\ No newline at end of file
diff --git a/tests/test_bincount.py b/tests/test_bincount.py
new file mode 100644
index 0000000..64a8c8e
--- /dev/null
+++ b/tests/test_bincount.py
@@ -0,0 +1,47 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+def generate_bincount_args():
+    # Parameter combinations: (size, has_weights, minlength)
+    cases = [
+        (100, False, 0),
+        (100, True, 0),
+        (100, True, 50),
+        (1000, False, 2000),  # minlength > max_val
+        (50, True, 10),       # minlength < max_val
+        (0, False, 5),        # empty input
+    ]
+    return "size, has_weights, minlength", cases
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_bincount_args())
+def test_bincount(size, has_weights, minlength):
+    device = "cuda"
+
+    # Generate random non-negative integer input.
+    if size > 0:
+        max_val = size  # keeps the values reasonably spread out
+        input_tensor = torch.randint(0, max_val, (size,), device=device, dtype=torch.int32)
+    else:
+        input_tensor = torch.tensor([], device=device, dtype=torch.int32)
+
+    weights = None
+    if has_weights and size > 0:
+        weights = torch.randn(size, device=device, dtype=torch.float32)
+    elif has_weights and size == 0:
+        weights = torch.tensor([], device=device, dtype=torch.float32)
+
+    # Run the reference implementation (PyTorch).
+    ref_out = torch.bincount(input_tensor, weights=weights, minlength=minlength)
+
+    # Run the ntops implementation.
+    ntops_out = ntops.torch.bincount(input_tensor, weights=weights, minlength=minlength)
+
+    # Compare the results; floating-point comparison needs a tolerance.
+    if ntops_out.is_floating_point():
+        assert torch.allclose(ntops_out, ref_out, atol=1e-4)
+    else:
+        assert torch.equal(ntops_out, ref_out)
\ No newline at end of file
diff --git a/tests/test_maximum.py b/tests/test_maximum.py
new file mode 100644
index 0000000..0be30f2
--- /dev/null
+++ b/tests/test_maximum.py
@@ -0,0 +1,62 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_maximum_elementwise(shape, dtype, device, rtol, atol):
+    """Basic element-wise maximum with identical shapes."""
+    # Generate random test data.
+    input_tensor = torch.randn(shape, dtype=dtype, device=device)
+    other_tensor = torch.randn(shape, dtype=dtype, device=device)
+
+    # 1. Run the DSL implementation (requires ntops.torch.maximum to be
+    #    exported).
+    ntops_result = ntops.torch.maximum(input_tensor, other_tensor)
+
+    # 2. Run the PyTorch reference implementation.
+    reference_result = torch.maximum(input_tensor, other_tensor)
+
+    # 3. Verify the results.
+    assert torch.allclose(ntops_result, reference_result, rtol=rtol, atol=atol)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_maximum_broadcasting(dtype):
+    """Broadcasting, e.g. (4, 1, 32) vs (1, 64, 32)."""
+    device = "cuda"
+
+    shape_a = (4, 1, 32)
+    shape_b = (1, 64, 32)
+    expected_shape = (4, 64, 32)
+
+    input_tensor = torch.randn(shape_a, dtype=dtype, device=device)
+    other_tensor = torch.randn(shape_b, dtype=dtype, device=device)
+
+    ntops_result = ntops.torch.maximum(input_tensor, other_tensor)
+    reference_result = torch.maximum(input_tensor, other_tensor)
+
+    # The shape must broadcast correctly...
+    assert ntops_result.shape == expected_shape
+    # ...and the values must match.
+    assert torch.allclose(ntops_result, reference_result)
+
+
+@skip_if_cuda_not_available
+def test_maximum_out_variant():
+    """The out= parameter."""
+    device = "cuda"
+
+    x = torch.randn(100, device=device)
+    y = torch.randn(100, device=device)
+    out = torch.empty_like(x)
+
+    ntops.torch.maximum(x, y, out=out)
+    expected = torch.maximum(x, y)
+
+    assert torch.allclose(out, expected)
\ No newline at end of file