diff --git a/src/ntops/kernels/gcd.py b/src/ntops/kernels/gcd.py new file mode 100644 index 0000000..a4f733d --- /dev/null +++ b/src/ntops/kernels/gcd.py @@ -0,0 +1,49 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor +from ntops.kernels.element_wise import arrangement + + +def application(input, other, output): + a = ntl.abs(ntl.cast(input, ntl.int64)) + b = ntl.abs(ntl.cast(other, ntl.int64)) + + while ntl.max(ntl.cast(b != 0, ntl.int32)) == 1: + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + output = ntl.cast(a, output.dtype) + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, other=0), + Tensor(ndim, dtype=dtype, other=0), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/glu.py b/src/ntops/kernels/glu.py new file mode 100644 index 0000000..2925735 --- /dev/null +++ b/src/ntops/kernels/glu.py @@ -0,0 +1,46 @@ +import functools +import ninetoothed.language as ntl +from ninetoothed import Tensor + +def arrangement(input, output, dim_size, dim, block_size): + ndim = input.ndim + if dim < 0: dim = ndim + dim + + tile_shape = [1] * ndim + tile_shape[dim] = block_size + + in_t = input.tile(tuple(tile_shape)) + out_t = output.tile(tuple(tile_shape)) + + for _ in range(ndim - 1): + + in_t.dtype = in_t.dtype.squeeze(0 if dim != 0 else 1) + out_t.dtype = out_t.dtype.squeeze(0 if dim != 0 else 1) + + if dim > 0: + dim -= 1 + + return in_t, out_t, dim_size + +def application(input, output, dim_size): + half = dim_size // 2 + + for i in range(half): + a = ntl.cast(input[i], ntl.float32) + b = ntl.cast(input[i + half], ntl.float32) + + res = a * ntl.sigmoid(b) + + output[i] = ntl.cast(res, output.dtype) + +def premake(ndim, dim, dim_size, dtype=None, block_size=None): + + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(0, constexpr=True, value=dim_size), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/select_scatter.py b/src/ntops/kernels/select_scatter.py new file mode 100644 index 0000000..78ad155 --- /dev/null +++ b/src/ntops/kernels/select_scatter.py @@ -0,0 +1,46 @@ +import functools +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, src, output, index, dim_size_pow2, dim, block_size): + ndim = input.ndim + if dim < 0: dim += ndim + non_target_dims = tuple(i for i in range(ndim) if i != dim) + + def _arrangement(t): + return t.permute(non_target_dims + (dim,)).flatten(end_dim=-1) + + # (Remaining, Dim_Size) + input_arranged = _arrangement(input).tile((block_size, -1)).squeeze(1) + src_arranged = _arrangement(src).tile((block_size, -1)).squeeze(1) + output_arranged = _arrangement(output).tile((block_size, -1)).squeeze(1) + + return input_arranged, src_arranged, output_arranged, index, dim_size_pow2 + +def application(input, src, output, target_index, dim_size_pow2): + col_indices = ntl.arange(0, dim_size_pow2) + + col_indices = ntl.expand_dims(col_indices, 0) + col_indices = ntl.broadcast_to(col_indices, (input.shape[0], dim_size_pow2)) + + actual_dim_size = input.shape[1] + + match_mask = (col_indices == ntl.cast(target_index, ntl.int32)) + valid_mask = col_indices < ntl.cast(actual_dim_size, ntl.int32) + + final_mask = match_mask & valid_mask + + output = ntl.where(final_mask, ntl.cast(src, output.dtype), ntl.cast(input, output.dtype)) + +def premake(ndim, dim, index, dim_size_pow2, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(0, constexpr=True, value=index), + Tensor(0, constexpr=True, value=dim_size_pow2), + ) + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/gcd.py b/src/ntops/torch/gcd.py new file mode 100644 index 0000000..bd4e267 --- /dev/null +++ b/src/ntops/torch/gcd.py @@ -0,0 +1,12 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + +def gcd(input, other, out=None): + if out is None: + out = torch.empty_like(input) + + block_size = 1024 + kernel = _cached_make(ntops.kernels.gcd.premake, input.ndim, input.dtype, block_size) + kernel(input, other, out) + return out \ No newline at end of file diff --git a/src/ntops/torch/glu.py b/src/ntops/torch/glu.py new file mode 100644 index 0000000..7c1bab5 --- /dev/null +++ b/src/ntops/torch/glu.py @@ -0,0 +1,25 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + +def glu(input, dim=-1): + ndim = input.ndim + if dim < 0: dim = ndim + dim + + dim_size = input.size(dim) + out_shape = list(input.shape) + out_shape[dim] //= 2 + output = torch.empty(out_shape, dtype=input.dtype, device=input.device) + block_size = 1024 + + kernel = _cached_make( + ntops.kernels.glu.premake, + ndim, + dim, + dim_size, + input.dtype, + block_size + ) + + kernel(input, output, dim_size) + return output \ No newline at end of file diff --git a/src/ntops/torch/select_scatter.py b/src/ntops/torch/select_scatter.py new file mode 100644 index 0000000..d738266 --- /dev/null +++ b/src/ntops/torch/select_scatter.py @@ -0,0 +1,23 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + +def select_scatter(input, src, dim, index): + ndim = input.ndim + if dim < 0: dim += ndim + + dim_size = input.shape[dim] + dim_size_pow2 = 1 << (dim_size - 1).bit_length() + + src_expanded = src.unsqueeze(dim) + output = torch.empty_like(input) + block_size = 1024 + + kernel = _cached_make( + ntops.kernels.select_scatter.premake, + ndim, dim, int(index), int(dim_size_pow2), + input.dtype, block_size + ) + + kernel(input, src_expanded, output, int(index), int(dim_size_pow2)) + return output \ No newline at end of file