Refactor Rx UserObject build process

这个提交包含在:
nmannall
2024-05-17 10:55:29 +01:00
父节点 888e33125c
当前提交 3db6ac0181
共有 2 个文件被更改,包括 6 次插入5 次删除

查看文件

@@ -79,7 +79,7 @@ class UserObjectMulti(ABC):
"""Creates object and adds it to model."""
pass
# TODO: Check if this is actually needed
# TODO: Make _do_rotate not use a grid object
def rotate(self, axis, angle, origin=None):
"""Rotates object (specialised for each object)."""
pass
@@ -942,13 +942,14 @@ class Rx(UserObjectMulti):
except KeyError:
pass
def build(self, grid, uip):
def build(self, model, uip):
try:
p1 = self.kwargs["p1"]
except KeyError:
logger.exception(self.params_str())
raise
grid = uip.grid
if self.do_rotate:
self._do_rotate(grid)
@@ -967,7 +968,7 @@ class Rx(UserObjectMulti):
r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})"
for key in RxUser.defaultoutputs:
r.outputs[key] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]
model.iterations, dtype=config.sim_config.dtypes["float_or_double"]
)
else:
outputs.sort()
@@ -980,7 +981,7 @@ class Rx(UserObjectMulti):
for field in outputs:
if field in allowableoutputs:
r.outputs[field] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"]
model.iterations, dtype=config.sim_config.dtypes["float_or_double"]
)
else:
logger.exception(

查看文件

@@ -29,7 +29,7 @@ class Rx:
allowableoutputs_dev = allowableoutputs[:-3]
def __init__(self):
self.ID = None
self.ID: str
self.outputs = {}
self.xcoord: int
self.ycoord: int