-
Notifications
You must be signed in to change notification settings - Fork 56
GPU-friendly truncation implementations #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8df520e
1481184
1e46b0d
8ef0425
8423ce8
848f0cc
7ae9b05
b1fe3bd
94ecfca
7395b8a
9af19b7
180afe6
efbe088
eafd7a8
f4892cf
3f273a1
6842a70
5bb2a23
ddd0ed6
f3b45ef
f26cffe
6ff9ac8
6666459
ebbdb84
2d7338a
bde9c50
d603b9e
5bc1506
b44878f
4a722ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,5 +17,6 @@ using TensorKit: MatrixAlgebraKit | |
| using Random | ||
|
|
||
| include("cutensormap.jl") | ||
| include("truncation.jl") | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <:CuVector{T}} | ||
|
|
||
| function MatrixAlgebraKit.findtruncated( | ||
| values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder | ||
| ) | ||
| I = sectortype(values) | ||
|
|
||
| dims = similar(values, Base.promote_op(dim, I)) | ||
| for (c, v) in pairs(dims) | ||
| fill!(v, dim(c)) | ||
| end | ||
|
|
||
| perm = sortperm(parent(values); strategy.by, strategy.rev) | ||
| cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) | ||
|
|
||
| result = similar(values, Bool) | ||
| parent(result)[perm] .= cumulative_dim .<= strategy.howmany | ||
| return result | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.findtruncated( | ||
| values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByError | ||
| ) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This probably needs to be special cased for
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just checked, and at least in MatrixAlgebraKit we are currently not supporting this. While I don't necessarily have any strong opinions in favor or against this, I will for now simply add a check and leave this as TBA? |
||
| (isfinite(strategy.p) && strategy.p > 0) || | ||
| throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) | ||
| ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) | ||
| ϵᵖ = similar(values, typeof(ϵᵖmax)) | ||
|
|
||
| # dimensions are all 1 so no need to account for weight | ||
| if FusionStyle(sectortype(values)) isa UniqueFusion | ||
| parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p | ||
| else | ||
| for (c, v) in pairs(values) | ||
| v′ = ϵᵖ[c] | ||
| v′ .= abs.(v) .^ strategy.p .* dim(c) | ||
| end | ||
| end | ||
|
|
||
| perm = sortperm(parent(values); by = abs, rev = false) | ||
| cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) | ||
|
|
||
| result = similar(values, Bool) | ||
| parent(result)[perm] .= cumulative_err .> ϵᵖmax | ||
| return result | ||
| end | ||
|
|
||
| # Needed until MatrixAlgebraKit patch hits... | ||
kshyatt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int}) | ||
| result = fill!(similar(A), false) | ||
| result[B] .= @view A[B] | ||
kshyatt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return result | ||
| end | ||
lkdvos marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.