Make UserObjectMulti and UserObjectSingle abstract

这个提交包含在:
nmannall
2024-04-18 17:34:54 +01:00
父节点 00d2c3c06e
当前提交 bfc4068fec
共有 2 个文件被更改,包括 243 次插入68 次删除

查看文件

@@ -18,14 +18,21 @@
import inspect import inspect
import logging import logging
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from scipy import interpolate from scipy import interpolate
import gprMax.config as config import gprMax.config as config
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.user_inputs import UserInput
from .cmds_geometry.cmds_geometry import UserObjectGeometry, rotate_2point_object, rotate_polarisation from .cmds_geometry.cmds_geometry import (
UserObjectGeometry,
rotate_2point_object,
rotate_polarisation,
)
from .geometry_outputs import GeometryObjects as GeometryObjectsUser from .geometry_outputs import GeometryObjects as GeometryObjectsUser
from .materials import DispersiveMaterial as DispersiveMaterialUser from .materials import DispersiveMaterial as DispersiveMaterialUser
from .materials import ListMaterial as ListMaterialUser from .materials import ListMaterial as ListMaterialUser
@@ -46,12 +53,12 @@ from .waveforms import Waveform as WaveformUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserObjectMulti: class UserObjectMulti(ABC):
"""Object that can occur multiple times in a model.""" """Object that can occur multiple times in a model."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.kwargs = kwargs self.kwargs = kwargs
self.order = None self.order = 0
self.hash = None self.hash = None
self.autotranslate = True self.autotranslate = True
self.do_rotate = False self.do_rotate = False
@@ -66,10 +73,12 @@ class UserObjectMulti:
return f"{self.hash}: {s[:-1]}" return f"{self.hash}: {s[:-1]}"
def build(self, grid, uip): @abstractmethod
def build(self, grid: FDTDGrid, uip: UserInput):
"""Creates object and adds it to grid.""" """Creates object and adds it to grid."""
pass pass
@abstractmethod
def rotate(self, axis, angle, origin=None): def rotate(self, axis, angle, origin=None):
"""Rotates object (specialised for each object).""" """Rotates object (specialised for each object)."""
pass pass
@@ -133,7 +142,9 @@ class ExcitationFile(UserObjectMulti):
waveformIDs = np.loadtxt(excitationfile, max_rows=1, dtype=str) waveformIDs = np.loadtxt(excitationfile, max_rows=1, dtype=str)
# Read all waveform values into an array # Read all waveform values into an array
waveformvalues = np.loadtxt(excitationfile, skiprows=1, dtype=config.sim_config.dtypes["float_or_double"]) waveformvalues = np.loadtxt(
excitationfile, skiprows=1, dtype=config.sim_config.dtypes["float_or_double"]
)
# Time array (if specified) for interpolation, otherwise use simulation time # Time array (if specified) for interpolation, otherwise use simulation time
if waveformIDs[0].lower() == "time": if waveformIDs[0].lower() == "time":
@@ -154,7 +165,9 @@ class ExcitationFile(UserObjectMulti):
w.type = "user" w.type = "user"
# Select correct column of waveform values depending on array shape # Select correct column of waveform values depending on array shape
singlewaveformvalues = waveformvalues[:] if len(waveformvalues.shape) == 1 else waveformvalues[:, i] singlewaveformvalues = (
waveformvalues[:] if len(waveformvalues.shape) == 1 else waveformvalues[:, i]
)
# Truncate waveform array if it is longer than time array # Truncate waveform array if it is longer than time array
if len(singlewaveformvalues) > len(waveformtime): if len(singlewaveformvalues) > len(waveformtime):
@@ -220,11 +233,14 @@ class Waveform(UserObjectMulti):
freq = self.kwargs["freq"] freq = self.kwargs["freq"]
ID = self.kwargs["id"] ID = self.kwargs["id"]
except KeyError: except KeyError:
logger.exception(self.params_str() + (" builtin waveforms " "require exactly four parameters.")) logger.exception(
self.params_str() + (" builtin waveforms " "require exactly four parameters.")
)
raise raise
if freq <= 0: if freq <= 0:
logger.exception( logger.exception(
self.params_str() + (" requires an excitation " "frequency value of greater than zero.") self.params_str()
+ (" requires an excitation " "frequency value of greater than zero.")
) )
raise ValueError raise ValueError
if any(x.ID == ID for x in grid.waveforms): if any(x.ID == ID for x in grid.waveforms):
@@ -253,7 +269,10 @@ class Waveform(UserObjectMulti):
fullargspec = inspect.getfullargspec(interpolate.interp1d) fullargspec = inspect.getfullargspec(interpolate.interp1d)
kwargs = dict(zip(reversed(fullargspec.args), reversed(fullargspec.defaults))) kwargs = dict(zip(reversed(fullargspec.args), reversed(fullargspec.defaults)))
except KeyError: except KeyError:
logger.exception(self.params_str() + (" a user-defined " "waveform requires at least two parameters.")) logger.exception(
self.params_str()
+ (" a user-defined " "waveform requires at least two parameters.")
)
raise raise
if "user_time" in self.kwargs: if "user_time" in self.kwargs:
@@ -276,7 +295,9 @@ class Waveform(UserObjectMulti):
w.type = wavetype w.type = wavetype
w.userfunc = interpolate.interp1d(waveformtime, uservalues, **kwargs) w.userfunc = interpolate.interp1d(waveformtime, uservalues, **kwargs)
logger.info(self.grid_name(grid) + (f"Waveform {w.ID} that is " "user-defined created.")) logger.info(
self.grid_name(grid) + (f"Waveform {w.ID} that is " "user-defined created.")
)
grid.waveforms.append(w) grid.waveforms.append(w)
@@ -354,12 +375,16 @@ class VoltageSource(UserObjectMulti):
p2 = uip.round_to_grid_static_point(p1) p2 = uip.round_to_grid_static_point(p1)
if resistance < 0: if resistance < 0:
logger.exception(self.params_str() + (" requires a source " "resistance of zero " "or greater.")) logger.exception(
self.params_str() + (" requires a source " "resistance of zero " "or greater.")
)
raise ValueError raise ValueError
# Check if there is a waveformID in the waveforms list # Check if there is a waveformID in the waveforms list
if not any(x.ID == waveform_id for x in grid.waveforms): if not any(x.ID == waveform_id for x in grid.waveforms):
logger.exception(self.params_str() + (" there is no waveform with " "the identifier {waveform_id}.")) logger.exception(
self.params_str() + (" there is no waveform with " "the identifier {waveform_id}.")
)
raise ValueError raise ValueError
v = VoltageSourceUser() v = VoltageSourceUser()
@@ -367,7 +392,16 @@ class VoltageSource(UserObjectMulti):
v.xcoord = xcoord v.xcoord = xcoord
v.ycoord = ycoord v.ycoord = ycoord
v.zcoord = zcoord v.zcoord = zcoord
v.ID = v.__class__.__name__ + "(" + str(v.xcoord) + "," + str(v.ycoord) + "," + str(v.zcoord) + ")" v.ID = (
v.__class__.__name__
+ "("
+ str(v.xcoord)
+ ","
+ str(v.ycoord)
+ ","
+ str(v.zcoord)
+ ")"
)
v.resistance = resistance v.resistance = resistance
v.waveformID = waveform_id v.waveformID = waveform_id
@@ -377,14 +411,21 @@ class VoltageSource(UserObjectMulti):
# Check source start & source remove time parameters # Check source start & source remove time parameters
if start < 0: if start < 0:
logger.exception( logger.exception(
self.params_str() + (" delay of the initiation " "of the source should not " "be less than zero.") self.params_str()
+ (" delay of the initiation " "of the source should not " "be less than zero.")
) )
raise ValueError raise ValueError
if stop < 0: if stop < 0:
logger.exception(self.params_str() + (" time to remove the " "source should not be " "less than zero.")) logger.exception(
self.params_str()
+ (" time to remove the " "source should not be " "less than zero.")
)
raise ValueError raise ValueError
if stop - start <= 0: if stop - start <= 0:
logger.exception(self.params_str() + (" duration of the source " "should not be zero or " "less.")) logger.exception(
self.params_str()
+ (" duration of the source " "should not be zero or " "less.")
)
raise ValueError raise ValueError
v.start = start v.start = start
v.stop = min(stop, grid.timewindow) v.stop = min(stop, grid.timewindow)
@@ -399,7 +440,9 @@ class VoltageSource(UserObjectMulti):
logger.info( logger.info(
f"{self.grid_name(grid)}Voltage source with polarity " f"{self.grid_name(grid)}Voltage source with polarity "
f"{v.polarisation} at {p2[0]:g}m, {p2[1]:g}m, {p2[2]:g}m, " f"{v.polarisation} at {p2[0]:g}m, {p2[1]:g}m, {p2[2]:g}m, "
f"resistance {v.resistance:.1f} Ohms," + startstop + f"using waveform {v.waveformID} created." f"resistance {v.resistance:.1f} Ohms,"
+ startstop
+ f"using waveform {v.waveformID} created."
) )
grid.voltagesources.append(v) grid.voltagesources.append(v)
@@ -478,7 +521,9 @@ class HertzianDipole(UserObjectMulti):
# Check if there is a waveformID in the waveforms list # Check if there is a waveformID in the waveforms list
if not any(x.ID == waveform_id for x in grid.waveforms): if not any(x.ID == waveform_id for x in grid.waveforms):
logger.exception(f"{self.params_str()} there is no waveform with the identifier {waveform_id}.") logger.exception(
f"{self.params_str()} there is no waveform with the identifier {waveform_id}."
)
raise ValueError raise ValueError
h = HertzianDipoleUser() h = HertzianDipoleUser()
@@ -511,10 +556,14 @@ class HertzianDipole(UserObjectMulti):
) )
raise ValueError raise ValueError
if stop < 0: if stop < 0:
logger.exception(f"{self.params_str()} time to remove the source should not be less than zero.") logger.exception(
f"{self.params_str()} time to remove the source should not be less than zero."
)
raise ValueError raise ValueError
if stop - start <= 0: if stop - start <= 0:
logger.exception(f"{self.params_str()} duration of the source should not be zero or less.") logger.exception(
f"{self.params_str()} duration of the source should not be zero or less."
)
raise ValueError raise ValueError
h.start = start h.start = start
h.stop = min(stop, grid.timewindow) h.stop = min(stop, grid.timewindow)
@@ -619,7 +668,9 @@ class MagneticDipole(UserObjectMulti):
# Check if there is a waveformID in the waveforms list # Check if there is a waveformID in the waveforms list
if not any(x.ID == waveform_id for x in grid.waveforms): if not any(x.ID == waveform_id for x in grid.waveforms):
logger.exception(f"{self.params_str()} there is no waveform with the identifier {waveform_id}.") logger.exception(
f"{self.params_str()} there is no waveform with the identifier {waveform_id}."
)
raise ValueError raise ValueError
m = MagneticDipoleUser() m = MagneticDipoleUser()
@@ -630,7 +681,16 @@ class MagneticDipole(UserObjectMulti):
m.xcoordorigin = xcoord m.xcoordorigin = xcoord
m.ycoordorigin = ycoord m.ycoordorigin = ycoord
m.zcoordorigin = zcoord m.zcoordorigin = zcoord
m.ID = m.__class__.__name__ + "(" + str(m.xcoord) + "," + str(m.ycoord) + "," + str(m.zcoord) + ")" m.ID = (
m.__class__.__name__
+ "("
+ str(m.xcoord)
+ ","
+ str(m.ycoord)
+ ","
+ str(m.zcoord)
+ ")"
)
m.waveformID = waveform_id m.waveformID = waveform_id
try: try:
@@ -639,14 +699,21 @@ class MagneticDipole(UserObjectMulti):
stop = self.kwargs["stop"] stop = self.kwargs["stop"]
if start < 0: if start < 0:
logger.exception( logger.exception(
self.params_str() + (" delay of the initiation " "of the source should not " "be less than zero.") self.params_str()
+ (" delay of the initiation " "of the source should not " "be less than zero.")
) )
raise ValueError raise ValueError
if stop < 0: if stop < 0:
logger.exception(self.params_str() + (" time to remove the " "source should not be " "less than zero.")) logger.exception(
self.params_str()
+ (" time to remove the " "source should not be " "less than zero.")
)
raise ValueError raise ValueError
if stop - start <= 0: if stop - start <= 0:
logger.exception(self.params_str() + (" duration of the source " "should not be zero or " "less.")) logger.exception(
self.params_str()
+ (" duration of the source " "should not be zero or " "less.")
)
raise ValueError raise ValueError
m.start = start m.start = start
m.stop = min(stop, grid.timewindow) m.stop = min(stop, grid.timewindow)
@@ -760,7 +827,9 @@ class TransmissionLine(UserObjectMulti):
# Check if there is a waveformID in the waveforms list # Check if there is a waveformID in the waveforms list
if not any(x.ID == waveform_id for x in grid.waveforms): if not any(x.ID == waveform_id for x in grid.waveforms):
logger.exception(f"{self.params_str()} there is no waveform with the identifier {waveform_id}.") logger.exception(
f"{self.params_str()} there is no waveform with the identifier {waveform_id}."
)
raise ValueError raise ValueError
t = TransmissionLineUser(grid) t = TransmissionLineUser(grid)
@@ -768,7 +837,16 @@ class TransmissionLine(UserObjectMulti):
t.xcoord = xcoord t.xcoord = xcoord
t.ycoord = ycoord t.ycoord = ycoord
t.zcoord = zcoord t.zcoord = zcoord
t.ID = t.__class__.__name__ + "(" + str(t.xcoord) + "," + str(t.ycoord) + "," + str(t.zcoord) + ")" t.ID = (
t.__class__.__name__
+ "("
+ str(t.xcoord)
+ ","
+ str(t.ycoord)
+ ","
+ str(t.zcoord)
+ ")"
)
t.resistance = resistance t.resistance = resistance
t.waveformID = waveform_id t.waveformID = waveform_id
@@ -778,14 +856,21 @@ class TransmissionLine(UserObjectMulti):
stop = self.kwargs["stop"] stop = self.kwargs["stop"]
if start < 0: if start < 0:
logger.exception( logger.exception(
self.params_str() + (" delay of the initiation " "of the source should not " "be less than zero.") self.params_str()
+ (" delay of the initiation " "of the source should not " "be less than zero.")
) )
raise ValueError raise ValueError
if stop < 0: if stop < 0:
logger.exception(self.params_str() + (" time to remove the " "source should not be " "less than zero.")) logger.exception(
self.params_str()
+ (" time to remove the " "source should not be " "less than zero.")
)
raise ValueError raise ValueError
if stop - start <= 0: if stop - start <= 0:
logger.exception(self.params_str() + (" duration of the source " "should not be zero or " "less.")) logger.exception(
self.params_str()
+ (" duration of the source " "should not be zero or " "less.")
)
raise ValueError raise ValueError
t.start = start t.start = start
t.stop = min(stop, grid.timewindow) t.stop = min(stop, grid.timewindow)
@@ -837,7 +922,11 @@ class Rx(UserObjectMulti):
def _do_rotate(self, grid): def _do_rotate(self, grid):
"""Performs rotation.""" """Performs rotation."""
new_pt = (self.kwargs["p1"][0] + grid.dx, self.kwargs["p1"][1] + grid.dy, self.kwargs["p1"][2] + grid.dz) new_pt = (
self.kwargs["p1"][0] + grid.dx,
self.kwargs["p1"][1] + grid.dy,
self.kwargs["p1"][2] + grid.dz,
)
pts = np.array([self.kwargs["p1"], new_pt]) pts = np.array([self.kwargs["p1"], new_pt])
rot_pts = rotate_2point_object(pts, self.axis, self.angle, self.origin) rot_pts = rotate_2point_object(pts, self.axis, self.angle, self.origin)
self.kwargs["p1"] = tuple(rot_pts[0, :]) self.kwargs["p1"] = tuple(rot_pts[0, :])
@@ -876,7 +965,9 @@ class Rx(UserObjectMulti):
# If no ID or outputs are specified, use default # If no ID or outputs are specified, use default
r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})" r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})"
for key in RxUser.defaultoutputs: for key in RxUser.defaultoutputs:
r.outputs[key] = np.zeros(grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]) r.outputs[key] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]
)
else: else:
outputs.sort() outputs.sort()
# Get allowable outputs # Get allowable outputs
@@ -887,7 +978,9 @@ class Rx(UserObjectMulti):
# Check and add field output names # Check and add field output names
for field in outputs: for field in outputs:
if field in allowableoutputs: if field in allowableoutputs:
r.outputs[field] = np.zeros(grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]) r.outputs[field] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]
)
else: else:
logger.exception( logger.exception(
f"{self.params_str()} contains an output " f"{self.params_str()} contains an output "
@@ -938,7 +1031,9 @@ class RxArray(UserObjectMulti):
dx, dy, dz = uip.discretise_point(dl) dx, dy, dz = uip.discretise_point(dl)
if xs > xf or ys > yf or zs > zf: if xs > xf or ys > yf or zs > zf:
logger.exception(f"{self.params_str()} the lower coordinates should be less than the upper coordinates.") logger.exception(
f"{self.params_str()} the lower coordinates should be less than the upper coordinates."
)
raise ValueError raise ValueError
if dx < 0 or dy < 0 or dz < 0: if dx < 0 or dy < 0 or dz < 0:
logger.exception(f"{self.params_str()} the step size should not be less than zero.") logger.exception(f"{self.params_str()} the step size should not be less than zero.")
@@ -991,7 +1086,9 @@ class RxArray(UserObjectMulti):
p5 = uip.round_to_grid_static_point(p5) p5 = uip.round_to_grid_static_point(p5)
r.ID = f"{r.__class__.__name__}({str(x)},{str(y)},{str(z)})" r.ID = f"{r.__class__.__name__}({str(x)},{str(y)},{str(z)})"
for key in RxUser.defaultoutputs: for key in RxUser.defaultoutputs:
r.outputs[key] = np.zeros(grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]) r.outputs[key] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]
)
logger.info( logger.info(
f" Receiver at {p5[0]:g}m, {p5[1]:g}m, " f" Receiver at {p5[0]:g}m, {p5[1]:g}m, "
f"{p5[2]:g}m with output component(s) " f"{p5[2]:g}m with output component(s) "
@@ -1102,13 +1199,29 @@ class Snapshot(UserObjectMulti):
logger.exception(f"{self.params_str()} the step size should not be less than zero.") logger.exception(f"{self.params_str()} the step size should not be less than zero.")
raise ValueError raise ValueError
if dx < 1 or dy < 1 or dz < 1: if dx < 1 or dy < 1 or dz < 1:
logger.exception(f"{self.params_str()} the step size should not be less than the spatial discretisation.") logger.exception(
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError raise ValueError
if iterations <= 0 or iterations > grid.iterations: if iterations <= 0 or iterations > grid.iterations:
logger.exception(f"{self.params_str()} time value is not valid.") logger.exception(f"{self.params_str()} time value is not valid.")
raise ValueError raise ValueError
s = SnapshotUser(xs, ys, zs, xf, yf, zf, dx, dy, dz, iterations, filename, fileext=fileext, outputs=outputs) s = SnapshotUser(
xs,
ys,
zs,
xf,
yf,
zf,
dx,
dy,
dz,
iterations,
filename,
fileext=fileext,
outputs=outputs,
)
logger.info( logger.info(
f"Snapshot from {p3[0]:g}m, {p3[1]:g}m, {p3[2]:g}m, to " f"Snapshot from {p3[0]:g}m, {p3[1]:g}m, {p3[2]:g}m, to "
@@ -1158,7 +1271,9 @@ class Material(UserObjectMulti):
if se != "inf": if se != "inf":
se = float(se) se = float(se)
if se < 0: if se < 0:
logger.exception(f"{self.params_str()} requires a positive value for electric conductivity.") logger.exception(
f"{self.params_str()} requires a positive value for electric conductivity."
)
raise ValueError raise ValueError
else: else:
se = float("inf") se = float("inf")
@@ -1251,13 +1366,17 @@ class AddDebyeDispersion(UserObjectMulti):
disp_material.deltaer.append(er_delta[i]) disp_material.deltaer.append(er_delta[i])
disp_material.tau.append(tau[i]) disp_material.tau.append(tau[i])
else: else:
logger.exception(f"{self.params_str()} requires positive values for the permittivity difference.") logger.exception(
f"{self.params_str()} requires positive values for the permittivity difference."
)
raise ValueError raise ValueError
if disp_material.poles > config.get_model_config().materials["maxpoles"]: if disp_material.poles > config.get_model_config().materials["maxpoles"]:
config.get_model_config().materials["maxpoles"] = disp_material.poles config.get_model_config().materials["maxpoles"] = disp_material.poles
# Replace original material with newly created DispersiveMaterial # Replace original material with newly created DispersiveMaterial
grid.materials = [disp_material if mat.numID == material.numID else mat for mat in grid.materials] grid.materials = [
disp_material if mat.numID == material.numID else mat for mat in grid.materials
]
logger.info( logger.info(
f"{self.grid_name(grid)}Debye disperion added to {disp_material.ID} " f"{self.grid_name(grid)}Debye disperion added to {disp_material.ID} "
@@ -1336,7 +1455,9 @@ class AddLorentzDispersion(UserObjectMulti):
config.get_model_config().materials["maxpoles"] = disp_material.poles config.get_model_config().materials["maxpoles"] = disp_material.poles
# Replace original material with newly created DispersiveMaterial # Replace original material with newly created DispersiveMaterial
grid.materials = [disp_material if mat.numID == material.numID else mat for mat in grid.materials] grid.materials = [
disp_material if mat.numID == material.numID else mat for mat in grid.materials
]
logger.info( logger.info(
f"{self.grid_name(grid)}Lorentz disperion added to {disp_material.ID} " f"{self.grid_name(grid)}Lorentz disperion added to {disp_material.ID} "
@@ -1410,7 +1531,9 @@ class AddDrudeDispersion(UserObjectMulti):
config.get_model_config().materials["maxpoles"] = disp_material.poles config.get_model_config().materials["maxpoles"] = disp_material.poles
# Replace original material with newly created DispersiveMaterial # Replace original material with newly created DispersiveMaterial
grid.materials = [disp_material if mat.numID == material.numID else mat for mat in grid.materials] grid.materials = [
disp_material if mat.numID == material.numID else mat for mat in grid.materials
]
logger.info( logger.info(
f"{self.grid_name(grid)}Drude disperion added to {disp_material.ID} " f"{self.grid_name(grid)}Drude disperion added to {disp_material.ID} "
@@ -1454,16 +1577,22 @@ class SoilPeplinski(UserObjectMulti):
raise raise
if sand_fraction < 0: if sand_fraction < 0:
logger.exception(f"{self.params_str()} requires a positive value for the sand fraction.") logger.exception(
f"{self.params_str()} requires a positive value for the sand fraction."
)
raise ValueError raise ValueError
if clay_fraction < 0: if clay_fraction < 0:
logger.exception(f"{self.params_str()} requires a positive value for the clay fraction.") logger.exception(
f"{self.params_str()} requires a positive value for the clay fraction."
)
raise ValueError raise ValueError
if bulk_density < 0: if bulk_density < 0:
logger.exception(f"{self.params_str()} requires a positive value for the bulk density.") logger.exception(f"{self.params_str()} requires a positive value for the bulk density.")
raise ValueError raise ValueError
if sand_density < 0: if sand_density < 0:
logger.exception(f"{self.params_str()} requires a positive value for the sand particle density.") logger.exception(
f"{self.params_str()} requires a positive value for the sand particle density."
)
raise ValueError raise ValueError
if water_fraction_lower < 0: if water_fraction_lower < 0:
logger.exception( logger.exception(
@@ -1484,7 +1613,12 @@ class SoilPeplinski(UserObjectMulti):
# Create a new instance of the Material class material # Create a new instance of the Material class material
# (start index after pec & free_space) # (start index after pec & free_space)
s = PeplinskiSoilUser( s = PeplinskiSoilUser(
ID, sand_fraction, clay_fraction, bulk_density, sand_density, (water_fraction_lower, water_fraction_upper) ID,
sand_fraction,
clay_fraction,
bulk_density,
sand_density,
(water_fraction_lower, water_fraction_upper),
) )
logger.info( logger.info(
@@ -1546,10 +1680,14 @@ class MaterialRange(UserObjectMulti):
) )
raise ValueError raise ValueError
if sigma_lower < 0: if sigma_lower < 0:
logger.exception(f"{self.params_str()} requires a positive value for the lower limit of conductivity.") logger.exception(
f"{self.params_str()} requires a positive value for the lower limit of conductivity."
)
raise ValueError raise ValueError
if ro_lower < 0: if ro_lower < 0:
logger.exception(f"{self.params_str()} requires a positive value for the lower range magnetic loss.") logger.exception(
f"{self.params_str()} requires a positive value for the lower range magnetic loss."
)
raise ValueError raise ValueError
if er_upper < 1: if er_upper < 1:
logger.exception( logger.exception(
@@ -1564,17 +1702,25 @@ class MaterialRange(UserObjectMulti):
) )
raise ValueError raise ValueError
if sigma_upper < 0: if sigma_upper < 0:
logger.exception(f"{self.params_str()} requires a positive value for the upper range of conductivity.") logger.exception(
f"{self.params_str()} requires a positive value for the upper range of conductivity."
)
raise ValueError raise ValueError
if ro_upper < 0: if ro_upper < 0:
logger.exception(f"{self.params_str()} requires a positive value for the upper range of magnetic loss.") logger.exception(
f"{self.params_str()} requires a positive value for the upper range of magnetic loss."
)
if any(x.ID == ID for x in grid.mixingmodels): if any(x.ID == ID for x in grid.mixingmodels):
logger.exception(f"{self.params_str()} with ID {ID} already exists") logger.exception(f"{self.params_str()} with ID {ID} already exists")
raise ValueError raise ValueError
s = RangeMaterialUser( s = RangeMaterialUser(
ID, (er_lower, er_upper), (sigma_lower, sigma_upper), (mr_lower, mr_upper), (ro_lower, ro_upper) ID,
(er_lower, er_upper),
(sigma_lower, sigma_upper),
(mr_lower, mr_upper),
(ro_lower, ro_upper),
) )
logger.info( logger.info(
@@ -1614,7 +1760,9 @@ class MaterialList(UserObjectMulti):
s = ListMaterialUser(ID, list_of_materials) s = ListMaterialUser(ID, list_of_materials)
logger.info(f"{self.grid_name(grid)}A list of materials used to create {s.ID} that includes {s.mat}, created") logger.info(
f"{self.grid_name(grid)}A list of materials used to create {s.ID} that includes {s.mat}, created"
)
grid.mixingmodels.append(s) grid.mixingmodels.append(s)
@@ -1682,15 +1830,23 @@ class GeometryView(UserObjectMulti):
logger.exception(f"{self.params_str()} the step size should not be less than zero.") logger.exception(f"{self.params_str()} the step size should not be less than zero.")
raise ValueError raise ValueError
if dx > grid.nx or dy > grid.ny or dz > grid.nz: if dx > grid.nx or dy > grid.ny or dz > grid.nz:
logger.exception(f"{self.params_str()} the step size should be less than the domain size.") logger.exception(
f"{self.params_str()} the step size should be less than the domain size."
)
raise ValueError raise ValueError
if dx < 1 or dy < 1 or dz < 1: if dx < 1 or dy < 1 or dz < 1:
logger.exception(f"{self.params_str()} the step size should not be less than the spatial discretisation.") logger.exception(
f"{self.params_str()} the step size should not be less than the spatial discretisation."
)
raise ValueError raise ValueError
if output_type not in ["n", "f"]: if output_type not in ["n", "f"]:
logger.exception(f"{self.params_str()} requires type to be either n (normal) or f (fine).") logger.exception(
f"{self.params_str()} requires type to be either n (normal) or f (fine)."
)
raise ValueError raise ValueError
if output_type == "f" and (dx * grid.dx != grid.dx or dy * grid.dy != grid.dy or dz * grid.dz != grid.dz): if output_type == "f" and (
dx * grid.dx != grid.dx or dy * grid.dy != grid.dy or dz * grid.dz != grid.dz
):
logger.exception( logger.exception(
f"{self.params_str()} requires the spatial " f"{self.params_str()} requires the spatial "
"discretisation for the geometry view to be the " "discretisation for the geometry view to be the "
@@ -1818,7 +1974,9 @@ class PMLCFS(UserObjectMulti):
or kappascalingdirection not in CFSParameter.scalingdirections or kappascalingdirection not in CFSParameter.scalingdirections
or sigmascalingdirection not in CFSParameter.scalingdirections or sigmascalingdirection not in CFSParameter.scalingdirections
): ):
logger.exception(f"{self.params_str()} must have scaling type {','.join(CFSParameter.scalingdirections)}") logger.exception(
f"{self.params_str()} must have scaling type {','.join(CFSParameter.scalingdirections)}"
)
raise ValueError raise ValueError
if ( if (
float(alphamin) < 0 float(alphamin) < 0
@@ -1827,7 +1985,9 @@ class PMLCFS(UserObjectMulti):
or float(kappamax) < 0 or float(kappamax) < 0
or float(sigmamin) < 0 or float(sigmamin) < 0
): ):
logger.exception(f"{self.params_str()} minimum and maximum scaling values must be greater than zero.") logger.exception(
f"{self.params_str()} minimum and maximum scaling values must be greater than zero."
)
raise ValueError raise ValueError
cfsalpha = CFSParameter() cfsalpha = CFSParameter()
@@ -1871,7 +2031,9 @@ class PMLCFS(UserObjectMulti):
grid.pmls["cfs"].append(cfs) grid.pmls["cfs"].append(cfs)
if len(grid.pmls["cfs"]) > 2: if len(grid.pmls["cfs"]) > 2:
logger.exception(f"{self.params_str()} can only be used up to two times, for up to a 2nd order PML.") logger.exception(
f"{self.params_str()} can only be used up to two times, for up to a 2nd order PML."
)
raise ValueError raise ValueError

查看文件

@@ -17,10 +17,13 @@
# along with gprMax. If not, see <http://www.gnu.org/licenses/>. # along with gprMax. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
from abc import ABC, abstractmethod
import numpy as np import numpy as np
import gprMax.config as config import gprMax.config as config
from gprMax.grid.fdtd_grid import FDTDGrid
from gprMax.user_inputs import UserInput
from .pml import PML from .pml import PML
from .utilities.host_info import set_omp_threads from .utilities.host_info import set_omp_threads
@@ -32,14 +35,14 @@ class Properties:
pass pass
class UserObjectSingle: class UserObjectSingle(ABC):
"""Object that can only occur a single time in a model.""" """Object that can only occur a single time in a model."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
# Each single command has an order to specify the order in which # Each single command has an order to specify the order in which
# the commands are constructed, e.g. discretisation must be # the commands are constructed, e.g. discretisation must be
# created before the domain # created before the domain
self.order = None self.order = 0
self.kwargs = kwargs self.kwargs = kwargs
self.props = Properties() self.props = Properties()
self.autotranslate = True self.autotranslate = True
@@ -47,9 +50,11 @@ class UserObjectSingle:
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self.props, k, v) setattr(self.props, k, v)
def build(self, grid, uip): @abstractmethod
def build(self, grid: FDTDGrid, uip: UserInput):
pass pass
@abstractmethod
def rotate(self, axis, angle, origin=None): def rotate(self, axis, angle, origin=None):
pass pass
@@ -95,17 +100,20 @@ class Discretisation(UserObjectSingle):
if G.dl[0] <= 0: if G.dl[0] <= 0:
logger.exception( logger.exception(
f"{self.__str__()} discretisation requires the " f"x-direction spatial step to be greater than zero" f"{self.__str__()} discretisation requires the "
f"x-direction spatial step to be greater than zero"
) )
raise ValueError raise ValueError
if G.dl[1] <= 0: if G.dl[1] <= 0:
logger.exception( logger.exception(
f"{self.__str__()} discretisation requires the " f"y-direction spatial step to be greater than zero" f"{self.__str__()} discretisation requires the "
f"y-direction spatial step to be greater than zero"
) )
raise ValueError raise ValueError
if G.dl[2] <= 0: if G.dl[2] <= 0:
logger.exception( logger.exception(
f"{self.__str__()} discretisation requires the " f"z-direction spatial step to be greater than zero" f"{self.__str__()} discretisation requires the "
f"z-direction spatial step to be greater than zero"
) )
raise ValueError raise ValueError
@@ -188,7 +196,8 @@ class TimeStepStabilityFactor(UserObjectSingle):
if f <= 0 or f > 1: if f <= 0 or f > 1:
logger.exception( logger.exception(
f"{self.__str__()} requires the value of the time " f"step stability factor to be between zero and one" f"{self.__str__()} requires the value of the time "
f"step stability factor to be between zero and one"
) )
raise ValueError raise ValueError
@@ -261,7 +270,9 @@ class OMPThreads(UserObjectSingle):
) )
raise raise
if n < 1: if n < 1:
logger.exception(f"{self.__str__()} requires the value to be an " f"integer not less than one") logger.exception(
f"{self.__str__()} requires the value to be an " f"integer not less than one"
)
raise ValueError raise ValueError
config.get_model_config().ompthreads = set_omp_threads(n) config.get_model_config().ompthreads = set_omp_threads(n)
@@ -290,7 +301,9 @@ class PMLProps(UserObjectSingle):
G.pmls["formulation"] = self.kwargs["formulation"] G.pmls["formulation"] = self.kwargs["formulation"]
if G.pmls["formulation"] not in PML.formulations: if G.pmls["formulation"] not in PML.formulations:
logger.exception( logger.exception(
self.__str__() + f" requires the value to be " + f"one of {' '.join(PML.formulations)}" self.__str__()
+ f" requires the value to be "
+ f"one of {' '.join(PML.formulations)}"
) )
except KeyError: except KeyError:
pass pass