Move create_G() to Model class

这个提交包含在:
nmannall
2024-03-27 13:43:12 +00:00
父节点 0001deafff
当前提交 204ba261ad
共有 3 个文件被更改,包括 23 次插入23 次删除

查看文件

@@ -35,7 +35,7 @@ from gprMax.config import ModelConfig
from ._version import __version__, codename
from .model import Model
from .solvers import create_G, create_solver
from .solvers import create_solver
from .utilities.host_info import print_cuda_info, print_host_info, print_opencl_info
from .utilities.utilities import get_terminal_width, logo, timer
@@ -140,8 +140,7 @@ class Context:
return scene
def _create_model(self) -> Model:
grid = create_G()
return Model(grid)
return Model()
def print_logo_copyright(self) -> None:
"""Prints gprMax logo, version, and copyright/licencing information."""

查看文件

@@ -26,6 +26,9 @@ import numpy as np
import psutil
from colorama import Fore, Style, init
from gprMax.grid.cuda_grid import CUDAGrid
from gprMax.grid.opencl_grid import OpenCLGrid
init()
from terminaltables import SingleTable
@@ -36,7 +39,7 @@ import gprMax.config as config
from .cython.yee_cell_build import build_electric_components, build_magnetic_components
from .fields_outputs import write_hdf5_outputfile
from .geometry_outputs import save_geometry_views
from .grid.fdtd_grid import dispersion_analysis
from .grid.fdtd_grid import FDTDGrid, dispersion_analysis
from .materials import process_materials
from .pml import CFS, build_pml, print_pml_info
from .snapshots import save_snapshots
@@ -49,8 +52,8 @@ logger = logging.getLogger(__name__)
class Model:
"""Builds and runs (solves) a model."""
def __init__(self, G):
self.G = G
def __init__(self):
self.G = self._create_grid()
# Monitor memory usage
self.p = None
@@ -60,6 +63,21 @@ class Model:
# later for use with CPU solver.
config.get_model_config().ompthreads = set_omp_threads(config.get_model_config().ompthreads)
def _create_grid(self) -> FDTDGrid:
"""Create grid object according to solver.
Returns:
grid: FDTDGrid class describing a grid in a model.
"""
if config.sim_config.general["solver"] == "cpu":
grid = FDTDGrid()
elif config.sim_config.general["solver"] == "cuda":
grid = CUDAGrid()
elif config.sim_config.general["solver"] == "opencl":
grid = OpenCLGrid()
return grid
def build(self):
"""Builds the Yee cells for a model."""

查看文件

@@ -29,23 +29,6 @@ from .updates.opencl_updates import OpenCLUpdates
from .updates.updates import Updates
def create_G() -> FDTDGrid:
"""Create grid object according to solver.
Returns:
G: FDTDGrid class describing a grid in a model.
"""
if config.sim_config.general["solver"] == "cpu":
G = FDTDGrid()
elif config.sim_config.general["solver"] == "cuda":
G = CUDAGrid()
elif config.sim_config.general["solver"] == "opencl":
G = OpenCLGrid()
return G
class Solver:
"""Generic solver for Update objects"""