diff --git a/Project.toml b/Project.toml index 934bdb6ed..ddcbe4141 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ FiniteDifferences = "0.12" GPUArrays = "11.3.1" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.2" +MatrixAlgebraKit = "0.6.3" Mooncake = "0.4.183" OhMyThreads = "0.8.0" Printf = "1" diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index 417970a02..f5efb98bb 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -17,5 +17,6 @@ using TensorKit: MatrixAlgebraKit using Random include("cutensormap.jl") +include("truncation.jl") end diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl new file mode 100644 index 000000000..019ded97b --- /dev/null +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -0,0 +1,52 @@ +const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <:CuVector{T}} + +function MatrixAlgebraKit.findtruncated( + values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder + ) + I = sectortype(values) + + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + + perm = sortperm(parent(values); strategy.by, strategy.rev) + cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_dim .<= strategy.howmany + return result +end + +function MatrixAlgebraKit.findtruncated( + values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByError + ) + (isfinite(strategy.p) && strategy.p > 0) || + throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) + ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) + ϵᵖ = similar(values, typeof(ϵᵖmax)) + + # dimensions are all 1 so no need to account for weight + if FusionStyle(sectortype(values)) isa UniqueFusion + parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p + else + for (c, v) in pairs(values) + v′ = ϵᵖ[c] + v′ .= abs.(v) .^ strategy.p .* dim(c) + end + end + + perm = sortperm(parent(values); by = abs, rev = false) + cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_err .> ϵᵖmax + return result +end + +# Needed until MatrixAlgebraKit patch hits... +function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int}) + result = fill!(similar(A), false) + result[B] .= @view A[B] + return result +end diff --git a/src/factorizations/adjoint.jl b/src/factorizations/adjoint.jl index 20d6d5986..eae8989ce 100644 --- a/src/factorizations/adjoint.jl +++ b/src/factorizations/adjoint.jl @@ -7,6 +7,7 @@ _adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs. _adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...) _adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...) _adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg)) +_adjoint(alg::TruncatedAlgorithm) = TruncatedAlgorithm(_adjoint(alg.alg), alg.trunc) _adjoint(alg::AbstractAlgorithm) = alg _adjoint(alg::MAK.CUSOLVER_HouseholderQR) = MAK.LQViaTransposedQR(alg) @@ -81,7 +82,7 @@ for (left_f, right_f) in zip( end # 3-arg functions -for f in (:svd_full, :svd_compact) +for f in (:svd_full, :svd_compact, :svd_trunc) f! = Symbol(f, :!) @eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap) return adjoint(MAK.copy_input($f, adjoint(t))) @@ -93,9 +94,16 @@ for f in (:svd_full, :svd_compact) return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg)))) end - @eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) - F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) - return reverse(adjoint.(F′)) + if f === :svd_trunc + function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return Vᴴ', S, U', ϵ + end + else + @eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) + end end # disambiguate by prohibition @@ -111,6 +119,15 @@ function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm) F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return reverse(adjoint.(F′)) end +function MAK.initialize_output( + ::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm + ) + return reverse(adjoint.(MAK.initialize_output(svd_trunc!, adjoint(t), _adjoint(alg)))) +end +function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::TruncatedAlgorithm) + U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return Vᴴ', S, U', ϵ +end function LinearAlgebra.isposdef(t::AdjointTensorMap) return isposdef(adjoint(t)) diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index bdcfebd74..dae550ea1 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -13,26 +13,6 @@ for f in ( @eval MAK.copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end -for f! in (:eig_full!, :eig_trunc!) - @eval function MAK.initialize_output( - ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm - ) - return d, similar(d) - end -end - -for f! in (:eigh_full!, :eigh_trunc!) - @eval function MAK.initialize_output( - ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm - ) - if scalartype(d) <: Real - return d, similar(d, space(d)) - else - return similar(d, real(scalartype(d))), similar(d, space(d)) - end - end -end - for f! in (:qr_full!, :qr_compact!) @eval function MAK.initialize_output( ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm @@ -93,7 +73,7 @@ end # For diagonal inputs we don't have to promote the scalartype since we know they are symmetric function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::DiagonalAlgorithm) V_D = fuse(domain(t)) - Tc = scalartype(t) + Tc = complex(scalartype(t)) A = similarstoragetype(t, Tc) return SectorVector{Tc, sectortype(t), A}(undef, V_D) end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index b9d060fec..e8e113ec1 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -29,15 +29,19 @@ end # --------- _blocklength(d::Integer, ind) = _blocklength(Base.OneTo(d), ind) _blocklength(ax, ind) = length(ax[ind]) +_blocklength(ax::Base.OneTo, ind::AbstractVector{<:Integer}) = length(ind) +_blocklength(ax::Base.OneTo, ind::AbstractVector{Bool}) = count(ind) + function truncate_space(V::ElementarySpace, inds) - return spacetype(V)(c => _blocklength(dim(V, c), ind) for (c, ind) in inds) + return spacetype(V)(c => _blocklength(dim(V, c), ind) for (c, ind) in pairs(inds)) end function truncate_domain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds) for (c, b) in blocks(tdst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, view(block(tsrc, c), :, I)) + b′ = block(tsrc, c) + b .= view(b′, :, I) end return tdst end @@ -45,7 +49,8 @@ function truncate_codomain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, in for (c, b) in blocks(tdst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, view(block(tsrc, c), I, :)) + b′ = block(tsrc, c) + b .= view(b′, I, :) end return tdst end @@ -53,7 +58,7 @@ function truncate_diagonal!(Ddst::DiagonalTensorMap, Dsrc::DiagonalTensorMap, in for (c, b) in blocks(Ddst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(diagview(b), view(diagview(block(Dsrc, c)), I)) + diagview(b) .= view(diagview(block(Dsrc, c)), I) end return Ddst end @@ -78,7 +83,7 @@ end function MAK.truncate( ::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) - extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(codomain(U)))) + extended_S = zerovector!(SectorVector{eltype(S), sectortype(S), storagetype(S)}(undef, fuse(codomain(U)))) for (c, b) in blocks(S) copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter end @@ -91,7 +96,7 @@ end function MAK.truncate( ::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) - extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(domain(Vᴴ)))) + extended_S = zerovector!(SectorVector{eltype(S), sectortype(S), storagetype(S)}(undef, fuse(domain(Vᴴ)))) for (c, b) in blocks(S) copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter end @@ -142,57 +147,11 @@ for f! in (:eig_trunc!, :eigh_trunc!) end end -# Find truncation -# --------------- +# findtruncated +# ------------- # auxiliary functions rtol_to_atol(S, p, atol, rtol) = rtol == 0 ? atol : max(atol, norm(S, p) * rtol) -function _compute_truncerr(Σdata, truncdim, p = 2) - I = keytype(Σdata) - S = scalartype(valtype(Σdata)) - return TensorKit._norm( - (c => @view(v[(get(truncdim, c, 0) + 1):end]) for (c, v) in Σdata), - p, zero(S) - ) -end - -function _findnexttruncvalue( - S, truncdim::SectorDict{I, Int}; by = identity, rev::Bool = true - ) where {I <: Sector} - # early return - (isempty(S) || all(iszero, values(truncdim))) && return nothing - if rev - σmin, imin = findmin(keys(truncdim)) do c - d = truncdim[c] - return by(S[c][d]) - end - return σmin, keys(truncdim)[imin] - else - σmax, imax = findmax(keys(truncdim)) do c - d = truncdim[c] - return by(S[c][d]) - end - return σmax, keys(truncdim)[imax] - end -end - -function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false) - values_sorted = similar(values) - perms = SectorDict( - ( - begin - p = sortperm(v; by, rev) - vs = values_sorted[c] - vs .= view(v, p) - c => p - end - ) for (c, v) in pairs(values) - ) - return values_sorted, perms -end - -# findtruncated -# ------------- # Generic fallback function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationStrategy) return MAK.findtruncated(values, strategy) @@ -202,25 +161,46 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation) return SectorDict(c => Colon() for c in keys(values)) end +# Need to select the first k values here after sorting across blocks, weighted by quantum dimension +# The strategy is therefore to sort all values, and then use a logical array to indicate +# which ones to keep. +# For GenericFusion, we additionally keep a vector of the quantum dimensions to provide the +# correct weight function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) - values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) - inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) - return SectorDict(c => perms[c][I] for (c, I) in inds) -end -function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) - I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) - totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0) - while totaldim > strategy.howmany - next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev) - isnothing(next) && break - _, cmin = next - truncdim[cmin] -= 1 - totaldim -= dim(cmin) - truncdim[cmin] == 0 && delete!(truncdim, cmin) + I = sectortype(values) + + # dimensions are all 1 so no need to account for weight + if FusionStyle(I) isa UniqueFusion + perm = partialsortperm(parent(values), 1:strategy.howmany; strategy.by, strategy.rev) + result = similar(values, Bool) + fill!(parent(result), false) + parent(result)[perm] .= true + return result + end + + # allocate vector of weights for each value + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) end - return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) + + # allocate logical array for the output + result = similar(values, Bool) + fill!(parent(result), false) + + # loop over sorted values and mark as to keep until dimension is reached + totaldim = 0 + for i in sortperm(parent(values); strategy.by, strategy.rev) + totaldim += dims[i] + totaldim > strategy.howmany && break + result[i] = true + end + + return result end +# disambiguate +MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) = + MAK.findtruncated(values, strategy) function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter) return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values)) @@ -237,28 +217,43 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByValue return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in pairs(values)) end +# Need to select the first k values here after sorting by error across blocks, +# where k is determined by the cumulative truncation error of these values. +# The strategy is therefore to sort all values, and then use a logical array to indicate +# which ones to keep. function MAK.findtruncated(values::SectorVector, strategy::TruncationByError) - values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) - inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) - return SectorDict(c => perms[c][I] for (c, I) in inds) -end -function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) - I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) - by(c, v) = abs(v)^strategy.p * dim(c) - Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), pairs(values)) - ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) - truncerrᵖ = zero(real(scalartype(valtype(values)))) - next = _findnexttruncvalue(values, truncdim) - while !isnothing(next) - σmin, cmin = next - truncerrᵖ += by(cmin, σmin) - truncerrᵖ >= ϵᵖ && break - (truncdim[cmin] -= 1) == 0 && delete!(truncdim, cmin) - next = _findnexttruncvalue(values, truncdim) + (isfinite(strategy.p) && strategy.p > 0) || + throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) + ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) + ϵᵖ = similar(values, typeof(ϵᵖmax)) + + # dimensions are all 1 so no need to account for weight + if FusionStyle(sectortype(values)) isa UniqueFusion + parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p + else + for (c, v) in pairs(values) + v′ = ϵᵖ[c] + v′ .= abs.(v) .^ strategy.p .* dim(c) + end end - return SectorDict{I, Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) + + # allocate logical array for the output + result = similar(values, Bool) + fill!(parent(result), true) + + # loop over sorted values and mark as to discard until maximal error is reached + totalerr = zero(eltype(ϵᵖ)) + for i in sortperm(parent(values); by = abs, rev = false) + totalerr += ϵᵖ[i] + totalerr > ϵᵖmax && break + result[i] = false + end + + return result end +# disambiguate +MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) = + MAK.findtruncated(values, strategy) function MAK.findtruncated(values::SectorVector, strategy::TruncationSpace) blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) @@ -273,8 +268,7 @@ function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersectio inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) return SectorDict( c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds; - init = trues(length(values[c])) + Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds ) for c in intersect(map(keys, inds)...) ) end @@ -282,8 +276,7 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationInterse inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) return SectorDict( c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds; - init = trues(length(values[c])) + Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds ) for c in intersect(map(keys, inds)...) ) end @@ -293,7 +286,7 @@ end MAK.truncation_error(values::SectorVector, ind) = MAK.truncation_error!(copy(values), ind) function MAK.truncation_error!(values::SectorVector, ind) - for (c, ind_c) in ind + for (c, ind_c) in pairs(ind) v = values[c] v[ind_c] .= zero(eltype(v)) end diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl index 4c914d6de..8f133b20e 100644 --- a/src/tensors/sectorvector.jl +++ b/src/tensors/sectorvector.jl @@ -36,6 +36,7 @@ Base.size(v::SectorVector, args...) = size(parent(v), args...) Base.similar(v::SectorVector) = SectorVector(similar(v.data), v.structure) Base.similar(v::SectorVector, ::Type{T}) where {T} = SectorVector(similar(v.data, T), v.structure) +Base.similar(v::SectorVector, V::ElementarySpace) = SectorVector{eltype(v), sectortype(V), storagetype(v)}(undef, V) Base.copy(v::SectorVector) = SectorVector(copy(v.data), v.structure) @@ -53,11 +54,13 @@ Base.keys(v::SectorVector) = keys(v.structure) Base.values(v::SectorVector) = (v[c] for c in keys(v)) Base.pairs(v::SectorVector) = SectorDict(c => v[c] for c in keys(v)) +Base.get(v::SectorVector{<:Any, I}, key::I, default) where {I} = haskey(v, key) ? v[key] : default +Base.haskey(v::SectorVector{<:Any, I}, key::I) where {I} = key in keys(v) + # TensorKit interface # ------------------- sectortype(::Type{T}) where {T <: SectorVector} = keytype(T) - -Base.similar(v::SectorVector, V::ElementarySpace) = SectorVector(undef, V) +storagetype(::Type{SectorVector{T, I, A}}) where {T, I, A} = A blocksectors(v::SectorVector) = keys(v) blocks(v::SectorVector) = pairs(v) @@ -108,3 +111,11 @@ LinearAlgebra.dot(v1::SectorVector, v2::SectorVector) = inner(v1, v2) function LinearAlgebra.norm(v::SectorVector, p::Real = 2) return _norm(blocks(v), p, float(zero(real(scalartype(v))))) end + +# Common functionality +# -------------------- +# specific overloads for performance and/or GPU +Base.minimum(x::SectorVector) = minimum(parent(x)) +Base.minimum(f, x::SectorVector) = minimum(f, parent(x)) +Base.maximum(x::SectorVector) = maximum(parent(x)) +Base.maximum(f, x::SectorVector) = maximum(f, parent(x)) diff --git a/test/cuda/factorizations.jl b/test/cuda/factorizations.jl index f7b6ad6d6..f3f15fe4b 100644 --- a/test/cuda/factorizations.jl +++ b/test/cuda/factorizations.jl @@ -229,17 +229,17 @@ for V in spacelist @test isisometric(N) @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) - #N = @constinferred left_null(t; trunc = (; atol = 100 * eps(norm(t)))) - #@test isisometric(N) - #@test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + N = @constinferred left_null(t; trunc = (; atol = 100 * eps(norm(t)))) + @test isisometric(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) Nᴴ = @constinferred right_null(t; alg = :svd) @test isisometric(Nᴴ; side = :right) @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) - #Nᴴ = @constinferred right_null(t; trunc = (; atol = 100 * eps(norm(t)))) - #@test isisometric(Nᴴ; side = :right) - #@test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) + Nᴴ = @constinferred right_null(t; trunc = (; atol = 100 * eps(norm(t)))) + @test isisometric(Nᴴ; side = :right) + @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) end # empty tensor @@ -258,15 +258,15 @@ for V in spacelist end end - #=@testset "truncated SVD" begin + @testset "truncated SVD" begin for T in eltypes, t in ( CUDA.randn(T, W, W), - #CUDA.randn(T, W, W)', + CUDA.randn(T, W, W)', CUDA.randn(T, W, V4), CUDA.randn(T, V4, W), - #CUDA.randn(T, W, V4)', - #CUDA.randn(T, V4, W)', + CUDA.randn(T, W, V4)', + CUDA.randn(T, V4, W)', DiagonalTensorMap(CUDA.randn(T, reduceddim(V1)), V1), ) @@ -286,7 +286,7 @@ for V in spacelist @test isisometric(U1) @test isisometric(Vᴴ1; side = :right) @test norm(t - U1 * S1 * Vᴴ1) ≈ ϵ1 atol = eps(real(T))^(4 / 5) - @test dim(domain(S1)) <= nvals + test_dim_isapprox(domain(S1), nvals) λ = minimum(diagview(S1)) trunc = trunctol(; atol = λ - 10eps(λ)) @@ -325,9 +325,9 @@ for V in spacelist @test isisometric(Vᴴ5; side = :right) @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(diagview(S5)) >= λ - @test dim(domain(S5)) ≤ nvals + test_dim_isapprox(domain(S5), nvals) end - end=# # TODO + end @testset "Eigenvalue decomposition" begin for T in eltypes, @@ -335,7 +335,7 @@ for V in spacelist CUDA.rand(T, V1, V1), CUDA.rand(T, W, W), CUDA.rand(T, W, W)', - DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1), + # DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1), ) d, v = @constinferred eig_full(t) @@ -349,10 +349,10 @@ for V in spacelist @test @constinferred isposdef(vdv) t isa DiagonalTensorMap || @test !isposdef(t) # unlikely for non-hermitian map - #=nvals = round(Int, dim(domain(t)) / 2) + nvals = round(Int, dim(domain(t)) / 2) d, v = @constinferred eig_trunc(t; trunc = truncrank(nvals)) @test t * v ≈ v * d - @test dim(domain(d)) ≤ nvals=# + test_dim_isapprox(domain(d), nvals) t2 = @constinferred project_hermitian(t) D, V = eigen(t2) @@ -380,10 +380,9 @@ for V in spacelist @test isposdef(t2 - λ * one(t) + 0.1 * one(t2)) @test !isposdef(t2 - λ * one(t) - 0.1 * one(t2)) - # TODO - #=d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) + d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) @test t2 * v ≈ v * d - @test dim(domain(d)) ≤ nvals=# + test_dim_isapprox(domain(d), nvals) end end diff --git a/test/setup.jl b/test/setup.jl index 6cde01d28..5c8516eb9 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -3,9 +3,11 @@ module TestSetup export smallset, randsector, hasfusiontensor, force_planar export random_fusion export sectorlist +export test_dim_isapprox export Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VSU₂U₁, Vfib, VIB_diag, VIB_M using Random +using Test: @test using TensorKit using TensorKit: ℙ, PlanarTrivial using Base.Iterators: take, product @@ -88,6 +90,17 @@ function random_fusion(I::Type{<:Sector}, ::Val{N}) where {N} # for fusion tree return (s, tail...) end +# helper function to check that d - dim(c) < dim(V) <= d where c is the largest sector +# to allow for truncations to have some margin with larger sectors +function test_dim_isapprox(V::ElementarySpace, d::Int) + dim_c_max = maximum(dim, sectors(V); init = 1) + return @test max(0, d - dim_c_max) ≤ dim(V) ≤ d + dim_c_max +end +function test_dim_isapprox(V::ProductSpace, d::Int) + dim_c_max = maximum(dim, blocksectors(V); init = 1) + return @test max(0, d - dim_c_max) ≤ dim(V) ≤ d + dim_c_max +end + sectorlist = ( Z2Irrep, Z3Irrep, Z4Irrep, Z3Irrep ⊠ Z4Irrep, U1Irrep, CU1Irrep, SU2Irrep, diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index 176d62657..41f30567b 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -259,7 +259,7 @@ for V in spacelist @test isisometric(U1) @test isisometric(Vᴴ1; side = :right) @test norm(t - U1 * S1 * Vᴴ1) ≈ ϵ1 atol = eps(real(T))^(4 / 5) - @test dim(domain(S1)) <= nvals + test_dim_isapprox(domain(S1), nvals) λ = minimum(diagview(S1)) trunc = trunctol(; atol = λ - 10eps(λ)) @@ -298,7 +298,7 @@ for V in spacelist @test isisometric(Vᴴ5; side = :right) @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(diagview(S5)) >= λ - @test dim(domain(S5)) ≤ nvals + test_dim_isapprox(domain(S5), nvals) end end @@ -323,7 +323,7 @@ for V in spacelist nvals = round(Int, dim(domain(t)) / 2) d, v = @constinferred eig_trunc(t; trunc = truncrank(nvals)) @test t * v ≈ v * d - @test dim(domain(d)) ≤ nvals + test_dim_isapprox(domain(d), nvals) t2 = @constinferred project_hermitian(t) D, V = eigen(t2) @@ -353,7 +353,7 @@ for V in spacelist d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) @test t2 * v ≈ v * d - @test dim(domain(d)) ≤ nvals + test_dim_isapprox(domain(d), nvals) end end