Refactor Rx UserObject build process

这个提交包含在:
nmannall
2024-05-17 10:55:29 +01:00
父节点 888e33125c
当前提交 3db6ac0181
共有 2 个文件被更改,包括 6 次插入5 次删除

查看文件

@@ -79,7 +79,7 @@ class UserObjectMulti(ABC):
"""Creates object and adds it to model.""" """Creates object and adds it to model."""
pass pass
# TODO: Check if this is actually needed # TODO: Make _do_rotate not use a grid object
def rotate(self, axis, angle, origin=None): def rotate(self, axis, angle, origin=None):
"""Rotates object (specialised for each object).""" """Rotates object (specialised for each object)."""
pass pass
@@ -942,13 +942,14 @@ class Rx(UserObjectMulti):
except KeyError: except KeyError:
pass pass
def build(self, grid, uip): def build(self, model, uip):
try: try:
p1 = self.kwargs["p1"] p1 = self.kwargs["p1"]
except KeyError: except KeyError:
logger.exception(self.params_str()) logger.exception(self.params_str())
raise raise
grid = uip.grid
if self.do_rotate: if self.do_rotate:
self._do_rotate(grid) self._do_rotate(grid)
@@ -967,7 +968,7 @@ class Rx(UserObjectMulti):
r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})" r.ID = f"{r.__class__.__name__}({str(r.xcoord)},{str(r.ycoord)},{str(r.zcoord)})"
for key in RxUser.defaultoutputs: for key in RxUser.defaultoutputs:
r.outputs[key] = np.zeros( r.outputs[key] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"] model.iterations, dtype=config.sim_config.dtypes["float_or_double"]
) )
else: else:
outputs.sort() outputs.sort()
@@ -980,7 +981,7 @@ class Rx(UserObjectMulti):
for field in outputs: for field in outputs:
if field in allowableoutputs: if field in allowableoutputs:
r.outputs[field] = np.zeros( r.outputs[field] = np.zeros(
grid.iterations, dtype=config.sim_config.dtypes["float_or_double"] model.iterations, dtype=config.sim_config.dtypes["float_or_double"]
) )
else: else:
logger.exception( logger.exception(

查看文件

@@ -29,7 +29,7 @@ class Rx:
allowableoutputs_dev = allowableoutputs[:-3] allowableoutputs_dev = allowableoutputs[:-3]
def __init__(self): def __init__(self):
self.ID = None self.ID: str
self.outputs = {} self.outputs = {}
self.xcoord: int self.xcoord: int
self.ycoord: int self.ycoord: int