Skip to content

Commit 362962e

Browse files
committed
PodVector_*.from_* NumPy/CuPy
Helpers to copy data directly from `NumPy`/`CuPy` to simplify user code, e.g., to add particles in AMReX codes.
1 parent a1e7a2c commit 362962e

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed

src/amrex/extensions/PODVector.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,130 @@ def podvector_to_xp(self, copy=False):
9898
return self.to_cupy(copy) if amr.Config.have_gpu else self.to_numpy(copy)
9999

100100

101+
def _is_host_accessible(cls):
102+
"""Check if a PODVector type's allocator provides host-accessible memory.
103+
104+
On CPU builds all allocators are host-accessible. On GPU builds only
105+
``std``, ``pinned`` and ``managed`` allocators live in host memory;
106+
the rest (``arena``, ``device``, ``async``, ...) are device-only.
107+
"""
108+
import inspect
109+
110+
amr = inspect.getmodule(cls)
111+
if not amr.Config.have_gpu:
112+
return True
113+
suffix = cls.__name__.rsplit("_", 1)[-1]
114+
return suffix in ("std", "pinned", "managed")
115+
116+
117+
def podvector_from_numpy(cls, arr):
118+
"""
119+
Create a new PODVector from a NumPy array (or array-like).
120+
121+
Always copies the data into a newly allocated PODVector.
122+
Only works for allocator types with host-accessible memory (e.g.,
123+
``std``, ``pinned``). For device-only allocators, use
124+
:meth:`from_cupy` or :meth:`from_xp` instead.
125+
126+
Parameters
127+
----------
128+
cls : type
129+
The PODVector type to construct.
130+
arr : array_like
131+
Input data, convertible to a NumPy array.
132+
133+
Returns
134+
-------
135+
PODVector
136+
A new PODVector with a copy of the data.
137+
138+
Raises
139+
------
140+
TypeError
141+
If the allocator is not host-accessible.
142+
"""
143+
import numpy as np
144+
145+
n = len(arr)
146+
if n == 0:
147+
return cls()
148+
149+
if not _is_host_accessible(cls):
150+
raise TypeError(
151+
f"{cls.__name__} is not host-accessible. "
152+
"Use from_cupy() or from_xp() instead."
153+
)
154+
155+
pv = cls(n)
156+
np.array(pv, copy=False)[:] = arr
157+
return pv
158+
159+
160+
def podvector_from_cupy(cls, arr):
161+
"""
162+
Create a new PODVector from a CuPy array (or array-like).
163+
164+
Always copies the data into a newly allocated PODVector.
165+
Works for every allocator type: for host-only allocators the
166+
data is staged to the host through NumPy automatically.
167+
168+
Parameters
169+
----------
170+
cls : type
171+
The PODVector type to construct.
172+
arr : array_like
173+
Input data, convertible to a CuPy array.
174+
175+
Returns
176+
-------
177+
PODVector
178+
A new PODVector with a copy of the data.
179+
"""
180+
import cupy as cp
181+
182+
n = len(arr)
183+
if n == 0:
184+
return cls()
185+
pv = cls(n)
186+
if _is_host_accessible(cls):
187+
import numpy as np
188+
189+
np.array(pv, copy=False)[:] = cp.asnumpy(arr)
190+
else:
191+
cp.asarray(pv)[:] = arr
192+
return pv
193+
194+
195+
def podvector_from_xp(cls, arr):
196+
"""
197+
Create a new PODVector from a NumPy or CuPy array,
198+
depending on amr.Config.have_gpu .
199+
200+
Always copies the data into a newly allocated PODVector.
201+
Unlike :meth:`to_xp`, a zero-copy view is not possible here because
202+
PODVector always owns its memory through its allocator.
203+
204+
This function is similar to CuPy's xp naming suggestion for CPU/GPU agnostic code:
205+
https://docs.cupy.dev/en/stable/user_guide/basic.html#how-to-write-cpu-gpu-agnostic-code
206+
207+
Parameters
208+
----------
209+
cls : type
210+
The PODVector type to construct.
211+
arr : array_like
212+
Input data (NumPy or CuPy array).
213+
214+
Returns
215+
-------
216+
PODVector
217+
A new PODVector with a copy of the data.
218+
"""
219+
import inspect
220+
221+
amr = inspect.getmodule(cls)
222+
return cls.from_cupy(arr) if amr.Config.have_gpu else cls.from_numpy(arr)
223+
224+
101225
def register_PODVector_extension(amr):
102226
"""PODVector helper methods"""
103227
import inspect
@@ -112,6 +236,12 @@ def register_PODVector_extension(amr):
112236
and member.__name__.startswith("PODVector_")
113237
),
114238
):
239+
# instance methods: PODVector -> array
115240
POD_type.to_numpy = podvector_to_numpy
116241
POD_type.to_cupy = podvector_to_cupy
117242
POD_type.to_xp = podvector_to_xp
243+
244+
# class methods: array -> PODVector
245+
POD_type.from_numpy = classmethod(podvector_from_numpy)
246+
POD_type.from_cupy = classmethod(podvector_from_cupy)
247+
POD_type.from_xp = classmethod(podvector_from_xp)

tests/test_podvector.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,49 @@ def test_array_interface():
4141

4242
podv[1] = 5
4343
assert arr[1] == podv[1] == 5
44+
45+
46+
def test_from_numpy():
47+
import numpy as np
48+
49+
# basic roundtrip
50+
arr = np.array([1.0, 2.5, 3.7, 4.0], dtype=np.float64)
51+
podv = amr.PODVector_real_std.from_numpy(arr)
52+
assert podv.size() == 4
53+
result = podv.to_numpy()
54+
np.testing.assert_array_equal(result, arr)
55+
56+
# from_numpy creates a copy, not a view
57+
arr[0] = 999.0
58+
assert podv[0] != 999.0
59+
60+
# empty array
61+
empty = np.array([], dtype=np.float64)
62+
podv_empty = amr.PODVector_real_std.from_numpy(empty)
63+
assert podv_empty.size() == 0
64+
65+
# from list (array-like)
66+
podv_list = amr.PODVector_real_std.from_numpy([10.0, 20.0])
67+
assert podv_list.size() == 2
68+
assert podv_list[1] == 20.0
69+
70+
71+
def test_from_xp():
72+
import numpy as np
73+
74+
arr = np.array([1.0, 2.0, 3.0])
75+
podv = amr.PODVector_real_std.from_xp(arr)
76+
assert podv.size() == 3
77+
result = podv.to_numpy()
78+
np.testing.assert_array_equal(result, arr)
79+
80+
81+
def test_from_xp_default():
82+
"""Test from_xp on the default (platform-adaptive) PODVector type alias."""
83+
import numpy as np
84+
85+
arr = np.array([5.0, 6.0, 7.0])
86+
podv = amr.PODVector_real_default.from_xp(arr)
87+
assert podv.size() == 3
88+
result = podv.to_numpy(copy=True)
89+
np.testing.assert_array_equal(result, arr)

0 commit comments

Comments
 (0)