Update dtype and type hints for numpy arrays

这个提交包含在:
nmannall
2024-07-18 17:10:39 +01:00
父节点 63042ab0ad
当前提交 b8e7a3b3ca
共有 4 个文件被更改,包括 27 次插入27 次删除

查看文件

@@ -24,6 +24,7 @@ from collections import OrderedDict
from typing import Any, Iterable, List, Tuple, Union from typing import Any, Iterable, List, Tuple, Union
import numpy as np import numpy as np
import numpy.typing as npt
from terminaltables import AsciiTable from terminaltables import AsciiTable
from tqdm import tqdm from tqdm import tqdm
from typing_extensions import TypeVar from typing_extensions import TypeVar
@@ -61,28 +62,28 @@ class FDTDGrid:
self.dt = 0.0 self.dt = 0.0
# Field Arrays # Field Arrays
self.Ex: np.ndarray[Any, np.dtype[np.single]] self.Ex: npt.NDArray[np.single]
self.Ey: np.ndarray[Any, np.dtype[np.single]] self.Ey: npt.NDArray[np.single]
self.Ez: np.ndarray[Any, np.dtype[np.single]] self.Ez: npt.NDArray[np.single]
self.Hx: np.ndarray[Any, np.dtype[np.single]] self.Hx: npt.NDArray[np.single]
self.Hy: np.ndarray[Any, np.dtype[np.single]] self.Hy: npt.NDArray[np.single]
self.Hz: np.ndarray[Any, np.dtype[np.single]] self.Hz: npt.NDArray[np.single]
# Dispersive Arrays # Dispersive Arrays
self.Tx: np.ndarray[Any, np.dtype[np.single]] self.Tx: npt.NDArray[np.single]
self.Ty: np.ndarray[Any, np.dtype[np.single]] self.Ty: npt.NDArray[np.single]
self.Tz: np.ndarray[Any, np.dtype[np.single]] self.Tz: npt.NDArray[np.single]
# Geometry Arrays # Geometry Arrays
self.solid: np.ndarray[Any, np.dtype[np.uint32]] self.solid: npt.NDArray[np.uint32]
self.rigidE: np.ndarray[Any, np.dtype[np.int8]] self.rigidE: npt.NDArray[np.int8]
self.rigidH: np.ndarray[Any, np.dtype[np.int8]] self.rigidH: npt.NDArray[np.int8]
self.ID: np.ndarray[Any, np.dtype[np.uint32]] self.ID: npt.NDArray[np.uint32]
# Update Coefficient Arrays # Update Coefficient Arrays
self.updatecoeffsE: np.ndarray self.updatecoeffsE: npt.NDArray
self.updatecoeffsH: np.ndarray self.updatecoeffsH: npt.NDArray
self.updatecoeffsdispersive: np.ndarray self.updatecoeffsdispersive: npt.NDArray
# PML parameters - set some defaults to use if not user provided # PML parameters - set some defaults to use if not user provided
self.pmls = {} self.pmls = {}

查看文件

@@ -56,7 +56,7 @@ class MPIGrid(FDTDGrid):
COORDINATOR_RANK = 0 COORDINATOR_RANK = 0
def __init__(self, comm: MPI.Cartcomm): def __init__(self, comm: MPI.Cartcomm):
self.size = np.zeros(3, dtype=int) self.size = np.zeros(3, dtype=np.intc)
super().__init__() super().__init__()
@@ -66,12 +66,12 @@ class MPIGrid(FDTDGrid):
self.z_comm = comm.Sub([True, True, False]) self.z_comm = comm.Sub([True, True, False])
self.pml_comm = MPI.COMM_NULL self.pml_comm = MPI.COMM_NULL
self.mpi_tasks = np.array(self.comm.dims) self.mpi_tasks = np.array(self.comm.dims, dtype=np.intc)
self.lower_extent: npt.NDArray[np.intc] = np.zeros(3, dtype=int) self.lower_extent = np.zeros(3, dtype=np.intc)
self.upper_extent: npt.NDArray[np.intc] = np.zeros(3, dtype=int) self.upper_extent = np.zeros(3, dtype=np.intc)
self.negative_halo_offset: npt.NDArray[np.bool_] = np.zeros(3, dtype=int) self.negative_halo_offset = np.zeros(3, dtype=np.bool_)
self.global_size: npt.NDArray[np.intc] = np.zeros(3, dtype=int) self.global_size = np.zeros(3, dtype=np.intc)
self.neighbours = np.full((3, 2), -1, dtype=int) self.neighbours = np.full((3, 2), -1, dtype=int)
self.neighbours[Dim.X] = self.comm.Shift(direction=Dim.X, disp=1) self.neighbours[Dim.X] = self.comm.Shift(direction=Dim.X, disp=1)

查看文件

@@ -17,7 +17,6 @@
# along with gprMax. If not, see <http://www.gnu.org/licenses/>. # along with gprMax. If not, see <http://www.gnu.org/licenses/>.
import numpy as np import numpy as np
import numpy.typing as npt
import gprMax.config as config import gprMax.config as config
@@ -32,8 +31,8 @@ class Rx:
def __init__(self): def __init__(self):
self.ID: str self.ID: str
self.outputs = {} self.outputs = {}
self.coord: npt.NDArray[np.int_] = np.zeros(3, dtype=int) self.coord = np.zeros(3, dtype=np.intc)
self.coordorigin: npt.NDArray[np.int_] = np.zeros(3, dtype=int) self.coordorigin = np.zeros(3, dtype=np.intc)
@property @property
def xcoord(self) -> int: def xcoord(self) -> int:

查看文件

@@ -33,8 +33,8 @@ class Source:
def __init__(self): def __init__(self):
self.ID: str self.ID: str
self.polarisation = None self.polarisation = None
self.coord: npt.NDArray[np.int_] = np.zeros(3, dtype=int) self.coord = np.zeros(3, dtype=np.intc)
self.coordorigin: npt.NDArray[np.int_] = np.zeros(3, dtype=int) self.coordorigin = np.zeros(3, dtype=np.intc)
self.start = 0.0 self.start = 0.0
self.stop = 0.0 self.stop = 0.0
self.waveform: Waveform self.waveform: Waveform