Docstrings and formatting tidying

这个提交包含在:
Craig Warren
2022-11-04 17:10:52 +00:00
父节点 69e3fe55c5
当前提交 1e177dc451
共有 7 个文件被更改,包括 120 次插入123 次删除

查看文件

@@ -1188,7 +1188,7 @@ class AddDebyeDispersion(UserObjectMulti):
class AddLorentzDispersion(UserObjectMulti):
"""Add dispersive properties to already defined :class:`Material` based
"""Adds dispersive properties to already defined :class:`Material` based
on multi-pole Lorentz formulation.
Attributes:

查看文件

@@ -58,10 +58,10 @@ class UserObjectSingle:
class Title(UserObjectSingle):
"""Allows you to include a title for your model.
"""Includes a title for your model.
:param name: Simulation title.
:type name: str, optional
Attributes:
name: string for model title.
"""
def __init__(self, **kwargs):
@@ -78,10 +78,10 @@ class Title(UserObjectSingle):
class Domain(UserObjectSingle):
"""Allows you to specify the size of the model.
"""Specifies the size of the model.
:param p1: point specifying total extend in x, y, z
:type p1: list of floats, non-optional
Attributes:
p1: tuple of floats specifying extent of model domain (x, y, z).
"""
def __init__(self, **kwargs):
@@ -99,7 +99,9 @@ class Domain(UserObjectSingle):
logger.exception(self.__str__() + ' requires at least one cell in every dimension')
raise ValueError
logger.info(f"Domain size: {self.kwargs['p1'][0]:g} x {self.kwargs['p1'][1]:g} x {self.kwargs['p1'][2]:g}m ({G.nx:d} x {G.ny:d} x {G.nz:d} = {(G.nx * G.ny * G.nz):g} cells)")
logger.info(f"Domain size: {self.kwargs['p1'][0]:g} x {self.kwargs['p1'][1]:g} x " +
f"{self.kwargs['p1'][2]:g}m ({G.nx:d} x {G.ny:d} x {G.nz:d} = " +
f"{(G.nx * G.ny * G.nz):g} cells)")
# Calculate time step at CFL limit; switch off appropriate PMLs for 2D
if G.nx == 1:
@@ -130,11 +132,10 @@ class Domain(UserObjectSingle):
class Discretisation(UserObjectSingle):
"""Allows you to specify the discretization of space in the x, y, and z
directions respectively.
"""Specifies the discretization of space in the x, y, and z directions.
:param p1: Specify discretisation in x, y, z direction
:type p1: list of floats, non-optional
Attributes:
p1: tuple of floats to specify spatial discretisation in x, y, z direction.
"""
def __init__(self, **kwargs):
@@ -150,25 +151,27 @@ class Discretisation(UserObjectSingle):
raise
if G.dl[0] <= 0:
logger.exception(self.__str__() + ' discretisation requires the x-direction spatial step to be greater than zero')
logger.exception(self.__str__() + ' discretisation requires the ' +
'x-direction spatial step to be greater than zero')
raise ValueError
if G.dl[1] <= 0:
logger.exception(self.__str__() + ' discretisation requires the y-direction spatial step to be greater than zero')
logger.exception(self.__str__() + ' discretisation requires the ' +
'y-direction spatial step to be greater than zero')
raise ValueError
if G.dl[2] <= 0:
logger.exception(self.__str__() + ' discretisation requires the z-direction spatial step to be greater than zero')
logger.exception(self.__str__() + ' discretisation requires the ' +
'z-direction spatial step to be greater than zero')
raise ValueError
logger.info(f'Spatial discretisation: {G.dl[0]:g} x {G.dl[1]:g} x {G.dl[2]:g}m')
class TimeWindow(UserObjectSingle):
"""Allows you to specify the total required simulated time
"""Specifies the total required simulated time.
:param time: Required simulated time in seconds
:type time: float, optional
:param iterations: Required number of iterations
:type iterations: int, optional
Attributes:
time: float of required simulated time in seconds.
iterations: int of required number of iterations.
"""
def __init__(self, **kwargs):
@@ -205,11 +208,11 @@ class TimeWindow(UserObjectSingle):
class OMPThreads(UserObjectSingle):
"""Allows you to control how many OpenMP threads (usually the number of
physical CPU cores available) are used when running the model.
"""Controls how many OpenMP threads (usually the number of physical CPU
cores available) are used when running the model.
:param n: Number of threads.
:type n: int, optional
Attributes:
n: int for number of threads.
"""
def __init__(self, **kwargs):
@@ -220,10 +223,12 @@ class OMPThreads(UserObjectSingle):
try:
n = self.kwargs['n']
except KeyError:
logger.exception(self.__str__() + ' requires exactly one parameter to specify the number of CPU OpenMP threads to use')
logger.exception(self.__str__() + ' requires exactly one parameter ' +
'to specify the number of CPU OpenMP threads to use')
raise
if n < 1:
logger.exception(self.__str__() + ' requires the value to be an integer not less than one')
logger.exception(self.__str__() + ' requires the value to be an ' +
'integer not less than one')
raise ValueError
config.get_model_config().ompthreads = set_omp_threads(n)
@@ -232,8 +237,8 @@ class OMPThreads(UserObjectSingle):
class TimeStepStabilityFactor(UserObjectSingle):
"""Factor by which to reduce the time step from the CFL limit.
:param f: Factor to multiple time step.
:type f: float, optional
Attributes:
f: float for factor to multiple time step.
"""
def __init__(self, **kwargs):
@@ -248,7 +253,8 @@ class TimeStepStabilityFactor(UserObjectSingle):
raise
if f <= 0 or f > 1:
logger.exception(self.__str__() + ' requires the value of the time step stability factor to be between zero and one')
logger.exception(self.__str__() + ' requires the value of the time ' +
'step stability factor to be between zero and one')
raise ValueError
G.dt = G.dt * f
@@ -256,7 +262,7 @@ class TimeStepStabilityFactor(UserObjectSingle):
class PMLFormulation(UserObjectSingle):
"""Allows you to specify the formulation (type) of the PML to be used.
"""Specifies the formulation (type) of the PML to be used.
Attributes:
pml: string specifying formulation of PML.
@@ -270,34 +276,26 @@ class PMLFormulation(UserObjectSingle):
try:
pml = self.kwargs['pml']
except KeyError:
logger.exception(self.__str__() + ' requires exactly one parameter to specify the formulation of PML to use')
logger.exception(self.__str__() + ' requires exactly one parameter ' +
'to specify the formulation of PML to use')
raise
if pml not in PML.formulations:
logger.exception(self.__str__() + f" requires the value to be one of {' '.join(PML.formulations)}")
logger.exception(self.__str__() + f" requires the value to be one " +
f"of {' '.join(PML.formulations)}")
raise ValueError
G.pmlformulation = pml
class PMLCells(UserObjectSingle):
"""Allows you to control the number of cells (thickness) of PML that are used
on the six sides of the model domain. Specify either single thickness or
thickness on each side.
"""Controls the number of cells (thickness) of PML that are used on the six
sides of the model domain. Specify either single thickness or thickness
on each side.
:param thickness: Thickness of PML on all 6 sides.
:type thickness: int, optional
:param x0: Thickness of PML on left side.
:type x0: int, optional
:param y0: Thickness of PML on the front side.
:type y0: int, optional
:param z0: Thickness of PML on bottom side.
:type z0: int, optional
:param xmax: Thickness of PML on right side.
:type xmax: int, optional
:param ymax: Thickness of PML on the back side.
:type ymax: int, optional
:param zmax: Thickness of PML on top side.
:type zmax: int, optional
Attributes:
thickness: int for thickness of PML on all 6 sides.
x0, y0, z0, xmax, ymax, zmax: ints of thickness of PML on individual
sides of the model domain.
"""
def __init__(self, **kwargs):
@@ -333,11 +331,10 @@ class PMLCells(UserObjectSingle):
class SrcSteps(UserObjectSingle):
"""Provides a simple method to allow you to move the location of all simple
sources.
"""Moves the location of all simple sources.
:param p1: increments (x,y,z) to move all simple sources
:type p1: list, non-optional
Attributes:
p1: tuple of float increments (x,y,z) to move all simple sources.
"""
def __init__(self, **kwargs):
@@ -351,15 +348,16 @@ class SrcSteps(UserObjectSingle):
logger.exception(self.__str__() + ' requires exactly three parameters')
raise
logger.info(f'Simple sources will step {G.srcsteps[0] * G.dx:g}m, {G.srcsteps[1] * G.dy:g}m, {G.srcsteps[2] * G.dz:g}m for each model run.')
logger.info(f'Simple sources will step {G.srcsteps[0] * G.dx:g}m, ' +
f'{G.srcsteps[1] * G.dy:g}m, {G.srcsteps[2] * G.dz:g}m ' +
f'for each model run.')
class RxSteps(UserObjectSingle):
"""Provides a simple method to allow you to move the location of all simple
receivers.
"""Moves the location of all receivers.
:param p1: increments (x,y,z) to move all simple receivers
:type p1: list, non-optional
Attributes:
p1: tuple of float increments (x,y,z) to move all receivers.
"""
def __init__(self, **kwargs):
@@ -373,20 +371,20 @@ class RxSteps(UserObjectSingle):
logger.exception(self.__str__() + ' requires exactly three parameters')
raise
logger.info(f'All receivers will step {G.rxsteps[0] * G.dx:g}m, {G.rxsteps[1] * G.dy:g}m, {G.rxsteps[2] * G.dz:g}m for each model run.')
logger.info(f'All receivers will step {G.rxsteps[0] * G.dx:g}m, ' +
f'{G.rxsteps[1] * G.dy:g}m, {G.rxsteps[2] * G.dz:g}m ' +
f'for each model run.')
class ExcitationFile(UserObjectSingle):
"""Allows you to specify an ASCII file that contains columns of amplitude
values that specify custom waveform shapes that can be used with sources
in the model.
"""An ASCII file that contains columns of amplitude values that specify
custom waveform shapes that can be used with sources in the model.
:param filepath: Excitation file path.
:type filepath: str, non-optional
:param kind: passed to the interpolation function (scipy.interpolate.interp1d).
:type kind: float, optional
:param fill_value: passed to the interpolation function (scipy.interpolate.interp1d).
:type fill_value: float, optional
Attributes:
filepath: string of excitation file path.
kind: string or int specifying interpolation kind passed to
scipy.interpolate.interp1d.
fill_value: float or 'extrapolate' passed to scipy.interpolate.interp1d.
"""
def __init__(self, **kwargs):
@@ -413,7 +411,8 @@ class ExcitationFile(UserObjectSingle):
excitationfile = Path(excitationfile)
# excitationfile = excitationfile.resolve()
if not excitationfile.exists():
excitationfile = Path(config.sim_config.input_file_path.parent, excitationfile)
excitationfile = Path(config.sim_config.input_file_path.parent,
excitationfile)
logger.info(f'Excitation file: {excitationfile}')
@@ -422,7 +421,8 @@ class ExcitationFile(UserObjectSingle):
waveformIDs = f.readline().split()
# Read all waveform values into an array
waveformvalues = np.loadtxt(excitationfile, skiprows=1, dtype=config.sim_config.dtypes['float_or_double'])
waveformvalues = np.loadtxt(excitationfile, skiprows=1,
dtype=config.sim_config.dtypes['float_or_double'])
# Time array (if specified) for interpolation, otherwise use simulation time
if waveformIDs[0].lower() == 'time':
@@ -458,16 +458,18 @@ class ExcitationFile(UserObjectSingle):
# Interpolate waveform values
w.userfunc = interpolate.interp1d(waveformtime, singlewaveformvalues, **kwargs)
logger.info(f"User waveform {w.ID} created using {timestr} and, if required, interpolation parameters (kind: {kwargs['kind']}, fill value: {kwargs['fill_value']}).")
logger.info(f"User waveform {w.ID} created using {timestr} and, if " +
f"required, interpolation parameters (kind: {kwargs['kind']}, " +
f"fill value: {kwargs['fill_value']}).")
G.waveforms.append(w)
class OutputDir(UserObjectSingle):
"""Allows you to control the directory where output file(s) will be stored.
"""Controls the directory where output file(s) will be stored.
:param dir: File path to directory.
:type dir: str, non-optional
Attributes:
dir: string of file path to directory.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -481,8 +483,8 @@ class NumberOfModelRuns(UserObjectSingle):
"""Number of times to run the simulation. This required when using multiple
class:Scene instances.
:param n: File path to directory.
:type n: str, non-optional
Attributes:
n: int of number of model runs.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)

查看文件

@@ -141,7 +141,7 @@ class ModelConfig:
def set_dispersive_material_types(self):
"""Set data type for disperive materials. Complex if Drude or Lorentz
"""Sets data type for disperive materials. Complex if Drude or Lorentz
materials are present. Real if Debye materials.
"""
if self.materials['drudelorentz']:
@@ -154,9 +154,9 @@ class ModelConfig:
self.materials['dispersiveCdtype'] = sim_config.dtypes['C_float_or_double']
def set_output_file_path(self, outputdir=None):
"""Output file path can be provided by the user via the API or an input
file command. If they haven't provided one use the input file path
instead.
"""Sets output file path. Can be provided by the user via the API or an
input file command. If they haven't provided one use the input file
path instead.
Args:
outputdir: string of output file directory given by input file command.
@@ -179,7 +179,7 @@ class ModelConfig:
self.output_file_path_ext = self.output_file_path.with_suffix('.h5')
def set_snapshots_dir(self):
"""Set directory to store any snapshots.
"""Sets directory to store any snapshots.
Returns:
snapshot_dir: Path to directory to store snapshot files in.
@@ -345,13 +345,13 @@ class SimulationConfig:
self.dtypes['C_complex'] = 'cdouble'
def _get_byteorder(self):
"""Check the byte order of system to use for VTK files, i.e. geometry
"""Checks the byte order of system to use for VTK files, i.e. geometry
views and snapshots.
"""
self.vtk_byteorder = 'LittleEndian' if sys.byteorder == 'little' else 'BigEndian'
def _set_model_start_end(self):
"""Set range for number of models to run (internally 0 index)."""
"""Sets range for number of models to run (internally 0 index)."""
if self.args.i:
modelstart = self.args.i - 1
modelend = modelstart + self.args.n
@@ -363,7 +363,7 @@ class SimulationConfig:
self.model_end = modelend
def _set_input_file_path(self):
"""Set input file path for CLI or API."""
"""Sets input file path for CLI or API."""
# API
if self.args.inputfile is None:
self.input_file_path = Path(self.args.outputfile)

查看文件

@@ -81,12 +81,12 @@ class Context:
self.print_time_report()
def print_logo_copyright(self):
"""Print gprMax logo, version, and copyright/licencing information."""
"""Prints gprMax logo, version, and copyright/licencing information."""
logo_copyright = logo(__version__ + ' (' + codename + ')')
logger.basic(logo_copyright)
def print_time_report(self):
"""Print the total simulation time based on context."""
"""Prints the total simulation time based on context."""
s = ("\n=== Simulation completed in [HH:MM:SS]: "
f"{datetime.timedelta(seconds=self.tsimend - self.tsimstart)}")
logger.basic(f"{s} {'=' * (get_terminal_width() - 1 - len(s))}\n")
@@ -112,9 +112,8 @@ class MPIContext(Context):
"""Process for running a single model.
Args:
work (dict): contains any additional information that is passed to
MPI workers. By default only model number (i) is
used.
work: dict of any additional information that is passed to MPI
workers. By default only model number (i) is used.
"""
# Create configuration for model

查看文件

@@ -31,10 +31,7 @@ def store_outputs(G):
"""Stores field component values for every receiver and transmission line.
Args:
iteration (int): Current iteration number.
Ex, Ey, Ez, Hx, Hy, Hz (memory view): Current electric and magnetic
field values.
G (FDTDGrid): Parameters describing a grid in a model.
G: FDTDGrid class describing a grid in a model.
"""
iteration = G.iteration
@@ -99,11 +96,11 @@ __global__ void store_outputs(int NRX, int iteration, const int* __restrict__ rx
def write_hdf5_outputfile(outputfile, G):
"""Write an output file in HDF5 (.h5) format.
"""Writes an output file in HDF5 (.h5) format.
Args:
outputfile (str): Name of the output file.
G (FDTDGrid): Parameters describing a grid in a model.
outputfile: string of the name of the output file.
G: FDTDGrid class describing a grid in a model.
"""
# Check for any receivers in subgrids
@@ -130,12 +127,12 @@ def write_hdf5_outputfile(outputfile, G):
def write_grid(basegrp, G, is_subgrid=False):
"""Write grid meta data and data to HDF5 group.
"""Writes grid meta data and data to HDF5 group.
Args:
basegrp (dict): HDF5 group.
G (FDTDGrid): Parameters describing a grid in a model.
is_subgrid (bool): Is grid instance the main grid or a subgrid.
basegrp: dict of HDF5 group.
G: FDTDGrid class describing a grid in a model.
is_subgrid: boolean for grid instance the main grid or a subgrid.
"""
# Write meta data for grid

查看文件

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
def get_host_info():
"""Get information about the machine, CPU, RAM, and OS.
"""Gets information about the machine, CPU, RAM, and OS.
Returns:
hostinfo: dict containing manufacturer and model of machine;
@@ -199,7 +199,7 @@ def get_host_info():
def print_host_info(hostinfo):
"""Print information about the machine, CPU, RAM, and OS.
"""Prints information about the machine, CPU, RAM, and OS.
Args:
hostinfo: dict containing manufacturer and model of machine;
@@ -266,7 +266,7 @@ def set_omp_threads(nthreads=None):
def mem_check_host(mem):
"""Check if the required amount of memory (RAM) is available on host.
"""Checks if the required amount of memory (RAM) is available on host.
Args:
mem: int for memory required (bytes).
@@ -279,7 +279,7 @@ def mem_check_host(mem):
def mem_check_device_snaps(total_mem, snaps_mem):
"""Check if the required amount of memory (RAM) for all snapshots can fit
"""Checks if the required amount of memory (RAM) for all snapshots can fit
on specified device.
Args:
@@ -305,7 +305,7 @@ def mem_check_device_snaps(total_mem, snaps_mem):
def mem_check_all(grids):
"""Check memory for all grids, including for any dispersive materials,
"""Checks memory for all grids, including for any dispersive materials,
snapshots, and if solver with GPU, whether snapshots will fit on GPU
memory.
@@ -356,7 +356,7 @@ def mem_check_all(grids):
def has_pycuda():
"""Check if pycuda module is installed."""
"""Checks if pycuda module is installed."""
pycuda = True
try:
import pycuda
@@ -366,7 +366,7 @@ def has_pycuda():
def has_pyopencl():
"""Check if pyopencl module is installed."""
"""Checks if pyopencl module is installed."""
pyopencl = True
try:
import pyopencl
@@ -376,7 +376,7 @@ def has_pyopencl():
def detect_cuda_gpus():
"""Get information about CUDA-capable GPU(s).
"""Gets information about CUDA-capable GPU(s).
Returns:
gpus: dict of detected pycuda device object(s) where where device ID(s)
@@ -415,7 +415,7 @@ def detect_cuda_gpus():
def print_cuda_info(devs):
""""Print info about detected CUDA-capable GPU(s).
""""Prints info about detected CUDA-capable GPU(s).
Args:
devs: dict of detected pycuda device object(s) where where device ID(s)
@@ -433,7 +433,7 @@ def print_cuda_info(devs):
def detect_opencl():
"""Get information about OpenCL platforms and devices.
"""Gets information about OpenCL platforms and devices.
Returns:
devs: dict of detected pyopencl device object(s) where where device ID(s)
@@ -465,7 +465,7 @@ def detect_opencl():
def print_opencl_info(devs):
""""Print info about detected OpenCL-capable device(s).
""""Prints info about detected OpenCL-capable device(s).
Args:
devs: dict of detected pyopencl device object(s) where where device ID(s)

查看文件

@@ -19,10 +19,8 @@
import datetime
import decimal as d
import logging
import os
import re
import textwrap
import xml.dom.minidom
from shutil import get_terminal_size
import numpy as np
@@ -41,10 +39,10 @@ except ImportError:
def get_terminal_width():
"""Get/set width of terminal being used.
"""Gets/sets width of terminal being used.
Returns:
terminalwidth: an int for the terminal width.
terminalwidth: int for the terminal width.
"""
terminalwidth = get_terminal_size()[0]
@@ -55,7 +53,7 @@ def get_terminal_width():
def logo(version):
"""Print gprMax logo, version, and licencing/copyright information.
"""Prints gprMax logo, version, and licencing/copyright information.
Args:
version: string for version number.
@@ -141,7 +139,7 @@ def round32(value):
def fft_power(waveform, dt):
"""Calculate a FFT of the given waveform of amplitude values;
"""Calculates FFT of the given waveform of amplitude values;
converted to decibels and shifted so that maximum power is 0dB
Args:
@@ -171,18 +169,19 @@ def fft_power(waveform, dt):
def human_size(size, a_kilobyte_is_1024_bytes=False):
"""Convert a file size to human-readable form.
"""Converts a file size to human-readable form.
Args:
size: int for file size in bytes.
a_kilobyte_is_1024_bytes: bool - true for multiples of 1024,
a_kilobyte_is_1024_bytes: boolean - true for multiples of 1024,
or false for multiples of 1000.
Returns:
Human-readable string of size.
"""
suffixes = {1000: ['KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'], 1024: ['KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']}
suffixes = {1000: ['KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'],
1024: ['KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']}
if size < 0:
raise ValueError('Number must be non-negative.')
@@ -207,15 +206,15 @@ def natural_keys(text):
def timer():
"""Function to return time in fractional seconds."""
"""Time in fractional seconds."""
return timer_fn()
def numeric_list_to_int_list(l):
"""Return a list of int from a numerical list."""
"""List of int from a numerical list."""
return list(map(int, l))
def numeric_list_to_float_list(l):
"""Return a list of float from a numerical list."""
"""List of float from a numerical list."""
return list(map(float, l))