Skip to content

Commit b3072ab

Browse files
committed
Add typing generics to constant
1 parent e51eacc commit b3072ab

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

python/dolfinx/fem/function.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import typing
1212
from collections.abc import Callable, Sequence
1313
from functools import cached_property, singledispatch
14+
from typing import Generic, TypeVar
1415

1516
import numpy as np
1617
import numpy.typing as npt
@@ -28,8 +29,10 @@
2829

2930
from dolfinx.mesh import Mesh
3031

32+
_S = TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar
3133

32-
class Constant(ufl.Constant):
34+
35+
class Constant(ufl.Constant, Generic[_S]):
3336
"""A constant with respect to a domain."""
3437

3538
_cpp_object: (
@@ -72,27 +75,27 @@ def value(self):
7275
return self._cpp_object.value
7376

7477
@value.setter
75-
def value(self, v):
78+
def value(self, v: npt.NDArray[_S]) -> None:
7679
np.copyto(self._cpp_object.value, np.asarray(v))
7780

7881
@property
79-
def dtype(self) -> np.dtype:
82+
def dtype(self) -> npt.DTypeLike:
8083
"""Value dtype of the constant."""
8184
return np.dtype(self._cpp_object.dtype)
8285

83-
def __float__(self):
86+
def __float__(self) -> float:
8487
"""Real representation of the constant."""
8588
if self.ufl_shape or self.ufl_free_indices:
8689
raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.")
87-
else:
88-
return float(self.value)
8990

90-
def __complex__(self):
91+
return float(self.value)
92+
93+
def __complex__(self) -> complex:
9194
"""Complex representation of the constant."""
9295
if self.ufl_shape or self.ufl_free_indices:
9396
raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.")
94-
else:
95-
return complex(self.value)
97+
98+
return complex(self.value)
9699

97100

98101
class Expression:

0 commit comments

Comments
 (0)