36 changes: 36 additions & 0 deletions src/ntops/kernels/gcd.py
@@ -0,0 +1,36 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor
from ntops.kernels.element_wise import arrangement


def application(input, other, output):
    # Work in int64 so intermediates cannot overflow, and take absolute
    # values, since the gcd is defined on magnitudes.
    a = ntl.abs(ntl.cast(input, ntl.int64))
    b = ntl.abs(ntl.cast(other, ntl.int64))

    # Vectorized Euclidean algorithm: iterate until b == 0 in every lane.
    # Converged lanes are frozen by the mask, and safe_b avoids a modulo
    # by zero in those lanes, so no extra unrolled steps are needed.
    while ntl.max(ntl.cast(b != 0, ntl.int32)) == 1:
        mask = b != 0
        safe_b = ntl.where(mask, b, 1)
        r = a % safe_b
        a = ntl.where(mask, b, a)
        b = ntl.where(mask, r, b)

    output = ntl.cast(a, output.dtype)


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, other=0),
        Tensor(ndim, dtype=dtype, other=0),
        Tensor(ndim, dtype=dtype),
    )

    return arrangement_, application, tensors
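Note: a minimal host-side sketch of the masked Euclidean step above, in plain PyTorch (the helper name is illustrative, not part of this PR):

```python
import torch


def gcd_reference(a, b):
    # Mirror the kernel: iterate while any lane still has b != 0;
    # converged lanes are frozen by the mask.
    a, b = a.abs().to(torch.int64), b.abs().to(torch.int64)
    while (b != 0).any():
        mask = b != 0
        safe_b = torch.where(mask, b, torch.ones_like(b))  # avoid x % 0
        r = a % safe_b
        a, b = torch.where(mask, b, a), torch.where(mask, r, b)
    return a
```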
53 changes: 53 additions & 0 deletions src/ntops/kernels/glu.py
@@ -0,0 +1,53 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, output, dim_size, dim, block_size):
    ndim = input.ndim
    if dim < 0:
        dim = ndim + dim

    # Give each program instance a block_size-wide tile of the target dim.
    tile_shape = [1] * ndim
    tile_shape[dim] = block_size

    in_t = input.tile(tuple(tile_shape))
    out_t = output.tile(tuple(tile_shape))

    # Squeeze the non-target singleton dims out of each tile so it is 1-D.
    for _ in range(ndim - 1):
        in_t.dtype = in_t.dtype.squeeze(0 if dim != 0 else 1)
        out_t.dtype = out_t.dtype.squeeze(0 if dim != 0 else 1)

        if dim > 0:
            dim -= 1

    return in_t, out_t, dim_size


def application(input, output, dim_size):
    # GLU: split the target dim in half and gate the first half with the
    # sigmoid of the second.
    half = dim_size // 2

    for i in range(half):
        a = ntl.cast(input[i], ntl.float32)
        b = ntl.cast(input[i + half], ntl.float32)

        res = a * ntl.sigmoid(b)

        output[i] = ntl.cast(res, output.dtype)


def premake(ndim, dim, dim_size, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0, constexpr=True, value=dim_size),
    )

    return arrangement_, application, tensors
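Note: the intended semantics as a plain PyTorch sketch (function name illustrative) — the target dim is split in half and the first half is gated by the sigmoid of the second:

```python
import torch


def glu_reference(input, dim=-1):
    # out = a * sigmoid(b), where a and b are the two halves along dim.
    a, b = input.chunk(2, dim=dim)
    return a * torch.sigmoid(b)
```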
52 changes: 52 additions & 0 deletions src/ntops/kernels/select_scatter.py
@@ -0,0 +1,52 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, src, output, index, dim_size_pow2, dim, block_size):
    ndim = input.ndim
    if dim < 0:
        dim += ndim

    non_target_dims = tuple(i for i in range(ndim) if i != dim)

    def _arrangement(t):
        # Move the target dim last, then flatten the rest: (remaining, dim_size).
        return t.permute(non_target_dims + (dim,)).flatten(end_dim=-1)

    input_arranged = _arrangement(input).tile((block_size, -1)).squeeze(1)
    src_arranged = _arrangement(src).tile((block_size, -1)).squeeze(1)
    output_arranged = _arrangement(output).tile((block_size, -1)).squeeze(1)

    return input_arranged, src_arranged, output_arranged, index, dim_size_pow2


def application(input, src, output, target_index, dim_size_pow2):
    # Column indices along the target dim, padded to a power of two.
    col_indices = ntl.arange(0, dim_size_pow2)
    col_indices = ntl.expand_dims(col_indices, 0)
    col_indices = ntl.broadcast_to(col_indices, (input.shape[0], dim_size_pow2))

    actual_dim_size = input.shape[1]

    # Take src at the target column; valid_mask guards the padding columns.
    match_mask = col_indices == ntl.cast(target_index, ntl.int32)
    valid_mask = col_indices < ntl.cast(actual_dim_size, ntl.int32)
    final_mask = match_mask & valid_mask

    output = ntl.where(final_mask, ntl.cast(src, output.dtype), ntl.cast(input, output.dtype))


def premake(ndim, dim, index, dim_size_pow2, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0, constexpr=True, value=index),
        Tensor(0, constexpr=True, value=dim_size_pow2),
    )

    return arrangement_, application, tensors
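Note: this kernel is intended to match torch.select_scatter, i.e. copy input and overwrite the slice at index along dim with src:

```python
import torch


def select_scatter_reference(input, src, dim, index):
    # Same result as torch.select_scatter(input, src, dim, index).
    output = input.clone()
    output.select(dim, index).copy_(src)
    return output
```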
14 changes: 14 additions & 0 deletions src/ntops/torch/gcd.py
@@ -0,0 +1,14 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def gcd(input, other, out=None):
    if out is None:
        out = torch.empty_like(input)

    block_size = 1024
    kernel = _cached_make(ntops.kernels.gcd.premake, input.ndim, input.dtype, block_size)
    kernel(input, other, out)
    return out
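Note: expected usage, assuming integer tensors of matching shape (sketch; the import path follows this PR's layout):

```python
import torch

from ntops.torch.gcd import gcd

a = torch.tensor([12, 18, 0], device="cuda")
b = torch.tensor([8, 24, 5], device="cuda")
print(gcd(a, b))  # tensor([4, 6, 5], device='cuda:0')
```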
32 changes: 32 additions & 0 deletions src/ntops/torch/glu.py
@@ -0,0 +1,32 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def glu(input, dim=-1):
    ndim = input.ndim
    if dim < 0:
        dim = ndim + dim

    dim_size = input.size(dim)
    # GLU needs an even size along the halved dim (as in torch.nn.functional.glu).
    if dim_size % 2 != 0:
        raise RuntimeError("glu: halving dimension must be even")

    out_shape = list(input.shape)
    out_shape[dim] //= 2
    output = torch.empty(out_shape, dtype=input.dtype, device=input.device)
    block_size = 1024

    kernel = _cached_make(
        ntops.kernels.glu.premake,
        ndim,
        dim,
        dim_size,
        input.dtype,
        block_size,
    )

    kernel(input, output, dim_size)
    return output
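Note: illustrative usage; the halved dimension must have an even size:

```python
import torch

from ntops.torch.glu import glu

x = torch.randn(4, 8, device="cuda")
y = glu(x, dim=-1)  # shape (4, 4): x[..., :4] * sigmoid(x[..., 4:])
```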
32 changes: 32 additions & 0 deletions src/ntops/torch/select_scatter.py
@@ -0,0 +1,32 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def select_scatter(input, src, dim, index):
    ndim = input.ndim
    if dim < 0:
        dim += ndim

    # The kernel's arange needs the target dim rounded up to a power of two.
    dim_size = input.shape[dim]
    dim_size_pow2 = 1 << (dim_size - 1).bit_length()

    # The kernel expects src with a singleton dim at the target position.
    src_expanded = src.unsqueeze(dim)
    output = torch.empty_like(input)
    block_size = 1024

    kernel = _cached_make(
        ntops.kernels.select_scatter.premake,
        ndim,
        dim,
        int(index),
        int(dim_size_pow2),
        input.dtype,
        block_size,
    )

    kernel(input, src_expanded, output, int(index), int(dim_size_pow2))
    return output
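Note: illustrative usage, mirroring torch.select_scatter (src has input's shape with dim removed):

```python
import torch

from ntops.torch.select_scatter import select_scatter

x = torch.zeros(2, 3, device="cuda")
s = torch.ones(2, device="cuda")
y = select_scatter(x, s, dim=1, index=1)  # column 1 of x becomes ones
```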