Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8df520e
try to make truncation GPU-friendly
lkdvos Jan 8, 2026
1481184
Temporarily fix StridedViews version
kshyatt Jan 20, 2026
1e46b0d
Revert "Temporarily fix StridedViews version"
kshyatt Jan 20, 2026
8ef0425
Small update for diagonal pullbacks
kshyatt Jan 22, 2026
8423ce8
Fix last error
kshyatt Jan 22, 2026
848f0cc
Reenable truncated CUDA tests
kshyatt Jan 22, 2026
7ae9b05
make truncation run on GPU
lkdvos Jan 22, 2026
b1fe3bd
bypass scalar indexing by specializing
lkdvos Jan 22, 2026
94ecfca
convenience overloads
lkdvos Jan 22, 2026
7395b8a
gpu-friendly copies
lkdvos Jan 22, 2026
9af19b7
retain storagetype in extended_S
lkdvos Jan 22, 2026
180afe6
avoid GPU issues with truncated adjoint tensormaps
lkdvos Jan 22, 2026
efbe088
various utility improvements
lkdvos Jan 22, 2026
eafd7a8
complete rewrite of implementation
lkdvos Jan 22, 2026
f4892cf
GPU doesn't like `trues`
lkdvos Jan 22, 2026
3f273a1
remove CUDA specializations and temporarily add missing MatrixAlgebra…
lkdvos Jan 22, 2026
6842a70
better dimension testing
lkdvos Jan 22, 2026
5bb2a23
fix unbound type parameter
lkdvos Jan 22, 2026
ddd0ed6
add missing import
lkdvos Jan 22, 2026
f3b45ef
be careful about double method definitions
lkdvos Jan 22, 2026
f26cffe
disable diagonal test
lkdvos Jan 23, 2026
6ff9ac8
bump MatrixAlgebraKit dependency
lkdvos Jan 23, 2026
6666459
Revert "disable diagonal test"
lkdvos Jan 23, 2026
ebbdb84
remove unnecessary specializations
lkdvos Jan 23, 2026
2d7338a
specialize CPU implementations
lkdvos Jan 23, 2026
bde9c50
add explanation TruncationByOrder
lkdvos Jan 26, 2026
d603b9e
add explanation and specialization TruncationByError
lkdvos Jan 26, 2026
5bc1506
fix stupidity
lkdvos Jan 26, 2026
b44878f
fix views
lkdvos Jan 27, 2026
4a722ef
enforce positive and finite p-norms
lkdvos Jan 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ FiniteDifferences = "0.12"
GPUArrays = "11.3.1"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.2"
MatrixAlgebraKit = "0.6.3"
Mooncake = "0.4.183"
OhMyThreads = "0.8.0"
Printf = "1"
Expand Down
1 change: 1 addition & 0 deletions ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ using TensorKit: MatrixAlgebraKit
using Random

include("cutensormap.jl")
include("truncation.jl")

end
52 changes: 52 additions & 0 deletions ext/TensorKitCUDAExt/truncation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Alias for a TensorKit.SectorVector whose underlying storage is a CuVector,
# i.e. sector-structured data living in GPU memory.
const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <:CuVector{T}}

"""
    findtruncated(values::CuSectorVector, strategy::TruncationByOrder)

GPU-friendly truncation selection: return a `Bool`-valued vector with the same
sector layout as `values`, marking the entries to keep so that the total kept
quantum dimension does not exceed `strategy.howmany`. All heavy lifting happens
on the flat `parent` storage to avoid scalar indexing on the GPU.
"""
function MatrixAlgebraKit.findtruncated(
        values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder
    )
    I = sectortype(values)

    # Build a vector with the same sector layout as `values`, where every entry
    # belonging to sector `c` holds the quantum dimension `dim(c)`.
    sectordims = similar(values, Base.promote_op(dim, I))
    foreach(pairs(sectordims)) do (c, block)
        fill!(block, dim(c))
    end

    # Sort the flat data as dictated by the strategy, accumulate dimensions in
    # that order, and cut off once the running total exceeds `howmany`.
    order = sortperm(parent(values); strategy.by, strategy.rev)
    total_dims = cumsum(Base.permute!(parent(sectordims), order))

    # Scatter the keep/discard decision back to the original positions.
    keepmask = similar(values, Bool)
    parent(keepmask)[order] .= total_dims .<= strategy.howmany
    return keepmask
end

"""
    findtruncated(values::CuSectorVector, strategy::TruncationByError)

GPU-friendly truncation selection by error: return a `Bool`-valued vector with
the same sector layout as `values`, marking the entries to keep so that the
p-norm of the discarded entries stays within the tolerances of `strategy`.
Only finite, strictly positive `p` is supported; `p == Inf` would need to be
special-cased and is currently not implemented.
"""
function MatrixAlgebraKit.findtruncated(
        values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByError
    )
    p = strategy.p
    (isfinite(p) && p > 0) ||
        throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported."))

    # Threshold on the p-th power of the truncation error: the larger of the
    # absolute tolerance and the relative tolerance scaled by the p-norm of `values`.
    ϵᵖmax = max(strategy.atol^p, strategy.rtol^p * norm(values, p))

    # Per-entry error contribution |v|^p, weighted by the sector's quantum dimension.
    ϵᵖ = similar(values, typeof(ϵᵖmax))
    if FusionStyle(sectortype(values)) isa UniqueFusion
        # dimensions are all 1 so no need to account for weight
        parent(ϵᵖ) .= abs.(parent(values)) .^ p
    else
        for (c, v) in pairs(values)
            block = ϵᵖ[c]
            block .= abs.(v) .^ p .* dim(c)
        end
    end

    # Accumulate contributions from smallest to largest absolute value; an entry
    # is kept as soon as discarding it (together with everything smaller) would
    # push the cumulative error past the threshold.
    order = sortperm(parent(values); by = abs, rev = false)
    cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), order))

    keepmask = similar(values, Bool)
    parent(keepmask)[order] .= cumulative_err .> ϵᵖmax
    return keepmask
end

# Temporary workaround until the corresponding MatrixAlgebraKit patch is released:
# intersect a boolean mask `A` with the index set `B`, entirely on the GPU.
function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int})
    # Start from an all-false mask and copy over only the entries selected by B,
    # so positions outside B are always false.
    mask = similar(A)
    fill!(mask, false)
    @views mask[B] .= A[B]
    return mask
end
25 changes: 21 additions & 4 deletions src/factorizations/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ _adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs.
_adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...)
_adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...)
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg))
# Adjoint of a truncated factorization: adjoint the wrapped algorithm while
# keeping the truncation strategy unchanged.
_adjoint(alg::TruncatedAlgorithm) = TruncatedAlgorithm(_adjoint(alg.alg), alg.trunc)
_adjoint(alg::AbstractAlgorithm) = alg

_adjoint(alg::MAK.CUSOLVER_HouseholderQR) = MAK.LQViaTransposedQR(alg)
Expand Down Expand Up @@ -81,7 +82,7 @@ for (left_f, right_f) in zip(
end

# 3-arg functions
for f in (:svd_full, :svd_compact)
for f in (:svd_full, :svd_compact, :svd_trunc)
f! = Symbol(f, :!)
@eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap)
return adjoint(MAK.copy_input($f, adjoint(t)))
Expand All @@ -93,9 +94,16 @@ for f in (:svd_full, :svd_compact)
return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg))))
end

@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
return reverse(adjoint.(F′))
if f === :svd_trunc
function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
return Vᴴ', S, U', ϵ
end
else
@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
return reverse(adjoint.(F′))
end
end

# disambiguate by prohibition
Expand All @@ -111,6 +119,15 @@ function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm)
F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
return reverse(adjoint.(F′))
end
# Allocate the output factors for a truncated SVD of an adjoint tensor map by
# delegating to the non-adjoint implementation and mapping the result back:
# the factors come out adjointed and in reversed order.
function MAK.initialize_output(
        ::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm
    )
    F = MAK.initialize_output(svd_trunc!, adjoint(t), _adjoint(alg))
    return reverse(map(adjoint, F))
end
# Truncated SVD of an adjoint tensor map: factorize the parent map instead and
# swap/adjoint the unitary factors (`t = (U S Vᴴ)' = V S Uᴴ`); the truncation
# error ϵ is unaffected.
function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::TruncatedAlgorithm)
    F′ = reverse(map(adjoint, F))
    U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), F′, _adjoint(alg))
    return adjoint(Vᴴ), S, adjoint(U), ϵ
end

function LinearAlgebra.isposdef(t::AdjointTensorMap)
return isposdef(adjoint(t))
Expand Down
22 changes: 1 addition & 21 deletions src/factorizations/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,6 @@ for f in (
@eval MAK.copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d)
end

for f! in (:eig_full!, :eig_trunc!)
@eval function MAK.initialize_output(
::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm
)
return d, similar(d)
end
end

for f! in (:eigh_full!, :eigh_trunc!)
@eval function MAK.initialize_output(
::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm
)
if scalartype(d) <: Real
return d, similar(d, space(d))
else
return similar(d, real(scalartype(d))), similar(d, space(d))
end
end
end

for f! in (:qr_full!, :qr_compact!)
@eval function MAK.initialize_output(
::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm
Expand Down Expand Up @@ -93,7 +73,7 @@ end
# Allocate the eigenvalue output for a diagonal map. The diff residue here left
# both the old (`Tc = scalartype(t)`) and new assignment in place; the post-change
# behavior is kept: eigenvalues are stored with the complexified scalar type, as
# eig_vals may produce complex values even for real inputs.
function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::DiagonalAlgorithm)
    V_D = fuse(domain(t))
    Tc = complex(scalartype(t))
    A = similarstoragetype(t, Tc)
    return SectorVector{Tc, sectortype(t), A}(undef, V_D)
end
Loading