Skip to content

Commit 3bed38d

Browse files
committed
More tweaks
1 parent 227b5f3 commit 3bed38d

File tree

7 files changed

+94
-13
lines changed

7 files changed

+94
-13
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ using TensorKit.Factorizations
1010
using TensorKit.Strided
1111
using TensorKit.Factorizations: AbstractAlgorithm
1212
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
13+
import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype
1414

1515
using TensorKit: MatrixAlgebraKit
1616

1717
using Random
1818

1919
include("cutensormap.jl")
2020
include("truncation.jl")
21+
include("auxiliary.jl")
2122

2223
end

ext/TensorKitCUDAExt/auxiliary.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Low-overhead `copyto!` from a 2-d CUDA-backed `StridedView` into a 1-d one.
# The linear source/destination indices are gathered on the host first, then a
# single fused broadcast performs the copy on the device, so no scalar GPU
# indexing takes place.
#
# Fixes: the `DimensionMismatch` message was missing its closing parenthesis,
# and the hot loop used scalar `append!` where `push!` (with `sizehint!`) is
# the idiomatic, allocation-friendly choice.
function TensorKit._copyto!(A::StridedView{TA, 1, <:CuArray{TA}}, B::StridedView{TB, 2, <:CuArray{TB}}) where {TA, TB}
    length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B)))"))

    Adata = parent(A)
    Astr = stride(A, 1)
    IA = A.offset

    Bdata = parent(B)
    Bstr = strides(B)
    IB_1 = B.offset

    # Build the index arrays on the host: one entry per copied element.
    n = length(A)
    IAs = Int[]
    IBs = Int[]
    sizehint!(IAs, n)
    sizehint!(IBs, n)
    @inbounds for _ in axes(B, 2)
        IB = IB_1
        for _ in axes(B, 1)
            IA += Astr
            push!(IAs, IA)
            IB += Bstr[1]
            push!(IBs, IB)
        end
        IB_1 += Bstr[2]
    end
    # Single vectorized gather/scatter on the device.
    Adata[IAs] .= Bdata[IBs]

    return A
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@ function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂,
77
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
88
end
99

10+
#=function TensorKit.TensorMap{T, S₁, N₁, N₂, A}(
11+
::UndefInitializer, space::TensorMapSpace{S₂, N₁, N₂}
12+
) where {T, S₁, S₂ <: TensorKit.ElementarySpace, N₁, N₂, A <: CuVector{T}}
13+
d = TensorKit.fusionblockstructure(space).totaldim
14+
data = A(undef, d)
15+
if !isbitstype(T)
16+
zerovector!(data)
17+
end
18+
return TensorKit.TensorMap{T, S₂, A}(data, space)
19+
end=#
20+
1021
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
1122
function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
1223
h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)
@@ -17,6 +28,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
1728
return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
1829
end
1930

31+
# Block data of a `CuTensorMap` is a contiguous 1-d `view` into the underlying
# device vector, hence this concrete `SubArray` type.
TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} =
    SubArray{T, 1, CuVector{T, CUDA.DeviceMemory}, Tuple{UnitRange{Int}}, true}
34+
2035
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
2136
@eval begin
2237
function CUDA.$fname(
@@ -102,9 +117,21 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
102117
end
103118

104119
# Convert a `TensorMap` to one with CUDA storage `A`. Returns `t` unchanged
# when it already has exactly the target type; otherwise allocates an
# uninitialized destination over the same space and copies the data into it.
function Base.convert(
        TT::Type{TensorMap{T, S, N₁, N₂, A}},
        t::TensorMap{T, S, N₁, N₂, AA}
    ) where {T, S, N₁, N₂, A <: CuArray{T}, AA}
    typeof(t) === TT && return t
    dest = TT(undef, space(t))
    return copy!(dest, t)
end
130+
131+
function Base.convert(
132+
TT::Type{TensorMap{T, S, N₁, N₂, A}},
133+
t::AdjointTensorMap
134+
) where {T, S, N₁, N₂, A <: CuArray{T}}
108135
if typeof(t) === TT
109136
return t
110137
else
@@ -140,6 +167,8 @@ end
140167

141168
# Promoting two CUDA storage types that differ only in their memory parameter
# yields `CUDA.default_memory`; both argument orders are covered so promotion
# is symmetric.
function TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N}
    return CuArray{T, N, CUDA.default_memory}
end
function TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{CuArray{T, N}}) where {T, N}
    return CuArray{T, N, CUDA.default_memory}
end
143172

144173

145174
# CuTensorMap exponentation:
@@ -168,3 +197,21 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168197
return tf
169198
end
170199
end
200+
201+
# CUDA variant of the non-threaded generic add/permute kernel.
# Subtransformers whose first component holds a single block are dispatched
# directly; the multi-block path first adapts that first component to a
# `CuArray` so the transform operates on device data.
function TensorKit._add_general_kernel_nonthreaded!(
        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # Buffers are allocated once and reused by every multi-block subtransform.
    buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for sub in transformer.data
        if length(sub[1]) == 1
            # Single-block special case: no intermediate buffers needed.
            TensorKit._add_transform_single!(tdst, tsrc, p, sub, α, β, backend...)
        else
            # Move the first component to the device, keep the rest as-is.
            devsub = (CUDA.adapt(CuArray, sub[1]), sub[2:end]...)
            TensorKit._add_transform_multi!(tdst, tsrc, p, devsub, buffers, α, β, backend...)
        end
    end
    return nothing
end

src/auxiliary/auxiliary.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ end
6060
# Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)`
6161
# used in indexmanipulations: avoids the overhead of Strided.jl
6262
function _copyto!(A::StridedView{<:Any, 1}, B::StridedView{<:Any, 2})
63-
length(A) == length(B) || throw(DimensionMismatch())
63+
length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))"))
6464

6565
Adata = parent(A)
6666
Astr = stride(A, 1)

src/tensors/braidingtensor.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,15 @@ end
171171
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
172172
# `add_transform!` with a `BraidingTensor` source: materialize the braiding
# tensor as a concrete `TensorMap` with the destination's scalar and storage
# type (so a GPU destination gets a GPU source), then forward to the generic
# method. A braiding tensor maps V1 ⊗ V2 → V2 ⊗ V1, hence the space below.
#
# Fixes: the space expression had lost its `⊗`/`←` operators (text-extraction
# garbling); restored from the braiding-tensor domain/codomain convention.
function add_transform!(
        tdst::AbstractTensorMap,
        tsrc::BraidingTensor{T, S},
        (p₁, p₂)::Index2Tuple,
        fusiontreetransform,
        α::Number, β::Number, backend::AbstractBackend...
    ) where {T, S}
    tsrc_map = TensorMapWithStorage{scalartype(tdst), storagetype(tdst)}(
        undef, (tsrc.V2 ⊗ tsrc.V1) ← (tsrc.V1 ⊗ tsrc.V2)
    )
    copy!(tsrc_map, tsrc)
    return add_transform!(
        tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
        backend...
    )
end

src/tensors/tensoroperations.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,10 @@ end
419419
# Scalar implementation
420420
#-----------------------
421421
# Extract the single scalar entry of a rank-(0, 0) tensor map, returning
# `zero(scalartype(t))` when every block is zero (or there are no blocks).
# Block data is `collect`ed to the host first so this also works for
# GPU-backed storage without scalar device indexing.
#
# Fixes: the previous version computed a nonzero mask (`nz_B_ends`) but never
# used it, and instead called `filter(any, B_ends)` — `any` over numeric data
# throws a `TypeError` (non-boolean used in boolean context). The mask itself
# (`!iszero.(B)`) also applied `!` to a `Bool` array, which is an error; the
# predicate form `any(!iszero, B)` is both correct and allocation-free.
function scalar(t::AbstractTensorMap{T, S, 0, 0}) where {T, S}
    B_ends = collect.(map(last, blocks(t)))
    valid_Bs = filter(B -> any(!iszero, B), B_ends)
    isempty(valid_Bs) && return zero(scalartype(t))
    # A (0, 0) tensor block holds exactly one entry; `only` asserts that.
    return only(first(valid_Bs))
end

src/tensors/treetransformers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc)
4646
end
4747

4848
# Per-tree payload of a generic tree transformer: a dense coefficient matrix
# plus destination- and source-side index structures.
# NOTE(review): widened from `Matrix{T}` to `DenseMatrix{T}` — presumably so
# GPU-backed dense matrices (e.g. `CuMatrix`) qualify; confirm against the
# CUDA extension's usage.
const _GenericTransformerData{T, N} = Tuple{
    DenseMatrix{T},
    Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}},
    Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}},
}

0 commit comments

Comments
 (0)