diff --git a/src/adapt.jl b/src/adapt.jl index 3239951f..3794e45b 100644 --- a/src/adapt.jl +++ b/src/adapt.jl @@ -50,3 +50,43 @@ function Adapt.adapt_structure(to,v::PSparseMatrix) row_par = v.row_partition PSparseMatrix(matrix_partition,row_par,col_par,v.assembled) end + +function Adapt.adapt_structure(to,v::PVector) + new_local_values = map(local_values(v)) do myvals + Adapt.adapt_structure(to,myvals) + end + new_cache = Adapt.adapt_structure(to,v.cache) + new_v = PVector(new_local_values,v.index_partition, new_cache) + new_v +end + +function Adapt.adapt_structure(to, cache::SplitVectorAssemblyCache) + # Adapt all the components + neighbors_snd = cache.neighbors_snd + neighbors_rcv = cache.neighbors_rcv + buffer_snd = map(cache.buffer_snd) do ja + Adapt.adapt_structure(to, ja) + end + buffer_rcv = map(cache.buffer_rcv) do ja + Adapt.adapt_structure(to, ja) + end + exchange_setup = cache.exchange_setup + ghost_indices_snd = map(cache.ghost_indices_snd) do ja + Adapt.adapt_structure(to, ja) + end + own_indices_rcv = map(cache.own_indices_rcv) do ja + Adapt.adapt_structure(to, ja) + end + + # Create new cache with adapted components + SplitVectorAssemblyCache( + neighbors_snd, + neighbors_rcv, + ghost_indices_snd, + own_indices_rcv, + buffer_snd, + buffer_rcv, + exchange_setup, + false + ) +end diff --git a/test/adapt_tests.jl b/test/adapt_tests.jl index df5c835f..fa31d636 100644 --- a/test/adapt_tests.jl +++ b/test/adapt_tests.jl @@ -14,8 +14,8 @@ function Adapt.adapt_storage(::Type{<:FakeCuVector},x::AbstractArray) end function adapt_tests(distribute) - - rank = distribute(LinearIndices((2,2))) + parts_per_dir = (2,2) + rank = distribute(LinearIndices(parts_per_dir) a = [[1,2],[3,4,5],Int[],[3,4]] b = JaggedArray(a) @@ -61,4 +61,14 @@ function adapt_tests(distribute) @test typeof(val_b) == FakeCuVector{typeof(val_a)} @test val_b.vector == val_a end -end \ No newline at end of file + + p = prod(parts_per_dir) + ranks = distribute(LinearIndices((p,))) + nodes_per_dir = map(i->2*i,parts_per_dir) + args = laplacian_fdm(nodes_per_dir,parts_per_dir,ranks) + A = psparse(args...) |> fetch + Adapt.adapt(FakeCuVector, A) + b = pzeros(axes(A, 2), split_format=true) + Adapt.adapt(FakeCuVector, b) + +end