from __future__ import annotations
import os
from abc import ABC, abstractmethod
import xarray as xr
[docs]
class MeteoCalculator(ABC):
"""Base class for explicit meteorological preprocessing calculations."""
registry: dict[str, type["MeteoCalculator"]] = {}
variable_name: str | None = None
default_output_name: str | None = None
source_variable_groups: tuple[tuple[str, ...], ...] = ()
def __init_subclass__(cls, **kwargs) -> None:
"""Register concrete calculator subclasses by derived variable name."""
super().__init_subclass__(**kwargs)
if cls.variable_name:
MeteoCalculator.registry[cls.variable_name] = cls
[docs]
@classmethod
def available_derivations(cls) -> tuple[str, ...]:
"""Return the derived meteorological variables supported by the registry."""
return tuple(cls.registry)
[docs]
@classmethod
def get_calculator_class(cls, variable: str) -> type["MeteoCalculator"]:
"""Resolve the calculator class that is responsible for a derived variable."""
try:
return cls.registry[variable]
except KeyError as exc:
supported = ", ".join(cls.available_derivations())
raise ValueError(
f"Unsupported derived variable '{variable}'. Supported derived variables are: {supported}."
) from exc
[docs]
@classmethod
def create_variable_dataset(
cls,
inputs: str | list[str] | tuple[str, ...],
variable: str,
*,
source_variable: str | tuple[str, ...] | list[str] | None = None,
output_name: str | None = None,
**kwargs,
) -> xr.Dataset:
"""Build a derived-variable dataset from one or more NetCDF inputs."""
calculator_class = cls.get_calculator_class(variable)
source_dataarrays = calculator_class._load_required_sources(
inputs,
calculator_class._resolve_required_source_groups(source_variable),
)
target_name = output_name or calculator_class.default_output_name
if target_name is None:
raise ValueError(f"Calculator '{variable}' does not define a default output name.")
return calculator_class._build_output_dataset(
source_dataarrays,
output_name=target_name,
**kwargs,
)
[docs]
@classmethod
def create_variable_file(
cls,
inputs: str | list[str] | tuple[str, ...],
variable: str,
output_path: str | None = None,
*,
source_variable: str | tuple[str, ...] | list[str] | None = None,
output_name: str | None = None,
**kwargs,
) -> str:
"""Compute a derived variable and write it to a NetCDF file."""
ds = cls.create_variable_dataset(
inputs,
variable,
source_variable=source_variable,
output_name=output_name,
**kwargs,
)
destination = output_path or cls._default_output_path(inputs, ds, output_name)
ds.to_netcdf(destination)
return destination
@classmethod
def _resolve_required_source_groups(
cls,
source_variable: str | tuple[str, ...] | list[str] | None,
) -> tuple[tuple[str, ...], ...]:
"""Resolve source-variable overrides against the calculator requirements."""
if source_variable is None:
return cls.source_variable_groups
if isinstance(source_variable, str):
if len(cls.source_variable_groups) != 1:
raise ValueError(
"A single 'source_variable' override can only be used for "
"single-source derivations."
)
return ((source_variable,),)
override_values = tuple(source_variable)
if len(cls.source_variable_groups) == 1:
return (override_values,)
if len(override_values) != len(cls.source_variable_groups):
raise ValueError(
"The 'source_variable' override must provide one source name for each "
"required source input."
)
return tuple((name,) for name in override_values)
@classmethod
def _load_required_sources(
cls,
inputs: str | list[str] | tuple[str, ...],
required_sources: tuple[tuple[str, ...], ...],
) -> dict[str, xr.DataArray]:
"""Load the source data arrays required to compute a derived variable."""
loaded_sources: dict[str, xr.DataArray] = {}
for candidate_group in required_sources:
data_array = None
resolved_name = None
if isinstance(inputs, str) and os.path.isdir(inputs):
file_path, resolved_name = cls._find_variable_file(inputs, candidate_group)
with xr.open_dataset(file_path, engine="netcdf4") as ds:
data_array = ds[resolved_name].load()
elif isinstance(inputs, str):
with xr.open_dataset(inputs, engine="netcdf4") as ds:
resolved_name = cls._resolve_variable_from_candidates(ds, candidate_group)
data_array = ds[resolved_name].load()
else:
for file_path in inputs:
with xr.open_dataset(file_path, engine="netcdf4") as ds:
try:
resolved_name = cls._resolve_variable_from_candidates(ds, candidate_group)
except ValueError:
continue
data_array = ds[resolved_name].load()
break
if data_array is None or resolved_name is None:
requested = ", ".join(candidate_group)
raise ValueError(
f"Could not find a NetCDF input containing one of the source variables: {requested}."
)
loaded_sources[resolved_name] = data_array
return loaded_sources
@classmethod
def _default_output_path(
cls,
inputs: str | list[str] | tuple[str, ...],
ds: xr.Dataset,
output_name: str | None,
) -> str:
"""Infer a default NetCDF output path from the source input location."""
target_name = next(iter(ds.data_vars), output_name or cls.default_output_name or "output")
if isinstance(inputs, str) and os.path.isdir(inputs):
return os.path.join(inputs, f"{target_name}.nc")
if isinstance(inputs, str):
folder = os.path.dirname(inputs)
extension = os.path.splitext(inputs)[1] or ".nc"
return os.path.join(folder, f"{target_name}{extension}")
first_path = inputs[0]
folder = os.path.dirname(first_path)
extension = os.path.splitext(first_path)[1] or ".nc"
return os.path.join(folder, f"{target_name}{extension}")
@staticmethod
def _find_variable_file(
folder_path: str,
variable_names: tuple[str, ...],
) -> tuple[str, str]:
"""Find the first NetCDF file in a folder containing one of the requested variables."""
for file_name in sorted(os.listdir(folder_path)):
if not file_name.endswith(".nc"):
continue
file_path = os.path.join(folder_path, file_name)
with xr.open_dataset(file_path, engine="netcdf4") as ds:
try:
resolved_name = MeteoCalculator._resolve_variable_from_candidates(
ds,
variable_names,
)
except ValueError:
continue
return file_path, resolved_name
raise ValueError(
"Could not find a NetCDF file containing any of the variables "
f"{', '.join(variable_names)} in '{folder_path}'."
)
@staticmethod
def _resolve_variable_from_candidates(
ds: xr.Dataset,
variable_names: tuple[str, ...],
) -> str:
"""Resolve the first matching variable name found in a dataset."""
for variable_name in variable_names:
if variable_name in ds.data_vars:
return variable_name
available = ", ".join(ds.data_vars)
requested = ", ".join(variable_names)
raise ValueError(
f"Dataset does not contain any of the requested variables ({requested}). "
f"Available variables are: {available}."
)
@classmethod
@abstractmethod
def _build_output_dataset(
cls,
source_dataarrays: dict[str, xr.DataArray],
*,
output_name: str,
**kwargs,
) -> xr.Dataset:
"""Build the derived-variable dataset for a concrete calculator."""
raise NotImplementedError