|
11 | 11 | import typing |
12 | 12 | from collections.abc import Callable, Sequence |
13 | 13 | from functools import cached_property, singledispatch |
| 14 | +from typing import Generic, TypeVar |
14 | 15 |
|
15 | 16 | import numpy as np |
16 | 17 | import numpy.typing as npt |
|
28 | 29 |
|
29 | 30 | from dolfinx.mesh import Mesh |
30 | 31 |
|
| 32 | +_S = TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar |
31 | 33 |
|
32 | | -class Constant(ufl.Constant): |
| 34 | + |
| 35 | +class Constant(ufl.Constant, Generic[_S]): |
33 | 36 | """A constant with respect to a domain.""" |
34 | 37 |
|
35 | 38 | _cpp_object: ( |
@@ -72,27 +75,27 @@ def value(self): |
72 | 75 | return self._cpp_object.value |
73 | 76 |
|
74 | 77 | @value.setter |
75 | | - def value(self, v): |
| 78 | + def value(self, v: npt.NDArray[_S]) -> None: |
76 | 79 | np.copyto(self._cpp_object.value, np.asarray(v)) |
77 | 80 |
|
78 | 81 | @property |
79 | | - def dtype(self) -> np.dtype: |
| 82 | + def dtype(self) -> npt.DTypeLike: |
80 | 83 | """Value dtype of the constant.""" |
81 | 84 | return np.dtype(self._cpp_object.dtype) |
82 | 85 |
|
83 | | - def __float__(self): |
| 86 | + def __float__(self) -> float: |
84 | 87 | """Real representation of the constant.""" |
85 | 88 | if self.ufl_shape or self.ufl_free_indices: |
86 | 89 | raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.") |
87 | | - else: |
88 | | - return float(self.value) |
89 | 90 |
|
90 | | - def __complex__(self): |
| 91 | + return float(self.value) |
| 92 | + |
| 93 | + def __complex__(self) -> complex: |
91 | 94 | """Complex representation of the constant.""" |
92 | 95 | if self.ufl_shape or self.ufl_free_indices: |
93 | 96 | raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.") |
94 | | - else: |
95 | | - return complex(self.value) |
| 97 | + |
| 98 | + return complex(self.value) |
96 | 99 |
|
97 | 100 |
|
98 | 101 | class Expression: |
|
0 commit comments