Update dtype and type hints for numpy arrays

这个提交包含在:
nmannall
2024-07-18 17:10:39 +01:00
父节点 63042ab0ad
当前提交 b8e7a3b3ca
共有 4 个文件被更改,包括 27 次插入27 次删除

查看文件

@@ -24,6 +24,7 @@ from collections import OrderedDict
from typing import Any, Iterable, List, Tuple, Union
import numpy as np
import numpy.typing as npt
from terminaltables import AsciiTable
from tqdm import tqdm
from typing_extensions import TypeVar
@@ -61,28 +62,28 @@ class FDTDGrid:
self.dt = 0.0
# Field Arrays
self.Ex: np.ndarray[Any, np.dtype[np.single]]
self.Ey: np.ndarray[Any, np.dtype[np.single]]
self.Ez: np.ndarray[Any, np.dtype[np.single]]
self.Hx: np.ndarray[Any, np.dtype[np.single]]
self.Hy: np.ndarray[Any, np.dtype[np.single]]
self.Hz: np.ndarray[Any, np.dtype[np.single]]
self.Ex: npt.NDArray[np.single]
self.Ey: npt.NDArray[np.single]
self.Ez: npt.NDArray[np.single]
self.Hx: npt.NDArray[np.single]
self.Hy: npt.NDArray[np.single]
self.Hz: npt.NDArray[np.single]
# Dispersive Arrays
self.Tx: np.ndarray[Any, np.dtype[np.single]]
self.Ty: np.ndarray[Any, np.dtype[np.single]]
self.Tz: np.ndarray[Any, np.dtype[np.single]]
self.Tx: npt.NDArray[np.single]
self.Ty: npt.NDArray[np.single]
self.Tz: npt.NDArray[np.single]
# Geometry Arrays
self.solid: np.ndarray[Any, np.dtype[np.uint32]]
self.rigidE: np.ndarray[Any, np.dtype[np.int8]]
self.rigidH: np.ndarray[Any, np.dtype[np.int8]]
self.ID: np.ndarray[Any, np.dtype[np.uint32]]
self.solid: npt.NDArray[np.uint32]
self.rigidE: npt.NDArray[np.int8]
self.rigidH: npt.NDArray[np.int8]
self.ID: npt.NDArray[np.uint32]
# Update Coefficient Arrays
self.updatecoeffsE: np.ndarray
self.updatecoeffsH: np.ndarray
self.updatecoeffsdispersive: np.ndarray
self.updatecoeffsE: npt.NDArray
self.updatecoeffsH: npt.NDArray
self.updatecoeffsdispersive: npt.NDArray
# PML parameters - set some defaults to use if not user provided
self.pmls = {}

查看文件

@@ -56,7 +56,7 @@ class MPIGrid(FDTDGrid):
COORDINATOR_RANK = 0
def __init__(self, comm: MPI.Cartcomm):
self.size = np.zeros(3, dtype=int)
self.size = np.zeros(3, dtype=np.intc)
super().__init__()
@@ -66,12 +66,12 @@ class MPIGrid(FDTDGrid):
self.z_comm = comm.Sub([True, True, False])
self.pml_comm = MPI.COMM_NULL
self.mpi_tasks = np.array(self.comm.dims)
self.mpi_tasks = np.array(self.comm.dims, dtype=np.intc)
self.lower_extent: npt.NDArray[np.intc] = np.zeros(3, dtype=int)
self.upper_extent: npt.NDArray[np.intc] = np.zeros(3, dtype=int)
self.negative_halo_offset: npt.NDArray[np.bool_] = np.zeros(3, dtype=int)
self.global_size: npt.NDArray[np.intc] = np.zeros(3, dtype=int)
self.lower_extent = np.zeros(3, dtype=np.intc)
self.upper_extent = np.zeros(3, dtype=np.intc)
self.negative_halo_offset = np.zeros(3, dtype=np.bool_)
self.global_size = np.zeros(3, dtype=np.intc)
self.neighbours = np.full((3, 2), -1, dtype=int)
self.neighbours[Dim.X] = self.comm.Shift(direction=Dim.X, disp=1)

查看文件

@@ -17,7 +17,6 @@
# along with gprMax. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
import numpy.typing as npt
import gprMax.config as config
@@ -32,8 +31,8 @@ class Rx:
def __init__(self):
self.ID: str
self.outputs = {}
self.coord: npt.NDArray[np.int_] = np.zeros(3, dtype=int)
self.coordorigin: npt.NDArray[np.int_] = np.zeros(3, dtype=int)
self.coord = np.zeros(3, dtype=np.intc)
self.coordorigin = np.zeros(3, dtype=np.intc)
@property
def xcoord(self) -> int:

查看文件

@@ -33,8 +33,8 @@ class Source:
def __init__(self):
self.ID: str
self.polarisation = None
self.coord: npt.NDArray[np.int_] = np.zeros(3, dtype=int)
self.coordorigin: npt.NDArray[np.int_] = np.zeros(3, dtype=int)
self.coord = np.zeros(3, dtype=np.intc)
self.coordorigin = np.zeros(3, dtype=np.intc)
self.start = 0.0
self.stop = 0.0
self.waveform: Waveform