from __future__ import annotations
import os
from dataclasses import dataclass
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from osgeo import gdal, ogr
from shapely.geometry import Point
from ..core import UnitHandler, projections, utils as u
from ..core.netcdf import fix_calendar
from ..physiographic.base import Basin
from .base import Meteo
from .schema import DEFAULT_METEO_SCHEMA, MeteoSchema
__methods__ = [
"linear",
"nearest",
]
@dataclass(frozen=True)
class NetCDFGridConfig:
    """Immutable settings for the NetCDF-to-CEQUEAU workflow.

    Attributes:
        time_name: Name of the time coordinate in the dataset.
        lat_name: Name of the latitude coordinate in the dataset.
        lon_name: Name of the longitude coordinate in the dataset.
        ce_index_offset: Value added to zero-based raster column/row
            positions to obtain the CE ``i``/``j`` indices.
    """

    # Coordinate names looked up in the input dataset.
    time_name: str = "time"
    lat_name: str = "lat"
    lon_name: str = "lon"
    # Shift applied to raster indices when building CE i/j indices.
    ce_index_offset: int = 10
class NetCDFMeteo(Meteo):
    """Meteorological workflow for gridded NetCDF datasets.

    The workflow prepares a gridded dataset (coordinate names, units, daily
    time axis), builds a table of grid points located on the CE raster,
    interpolates the fields onto the CE grid and reshapes the result into
    the CEQUEAU ``(pasTemp, CEid)`` layout.
    """

    def __init__(
        self,
        basin_struct: Basin,
        ds: xr.Dataset,
        config: NetCDFGridConfig | None = None,
        schema: MeteoSchema | None = None,
    ) -> None:
        """Create a gridded meteorological workflow from a NetCDF dataset.

        Args:
            basin_struct: Basin description providing the rasters, shapefile
                and project paths used by the workflow.
            ds: Input meteorological dataset.
            config: Coordinate names and CE index offset; defaults to
                ``NetCDFGridConfig()``.
            schema: Variable schema; defaults to ``DEFAULT_METEO_SCHEMA``.

        Raises:
            ValueError: If the dataset is missing required coordinates, is
                not daily, or contains variables that normalize to the same
                canonical name.
        """
        self.config = config or NetCDFGridConfig()
        self.schema = schema or DEFAULT_METEO_SCHEMA
        # Forward the resolved schema: without it ``prepare_dataset`` falls
        # back to the default schema and ignores the caller's one.
        prepared_ds = self.prepare_dataset(
            ds,
            schema=self.schema,
            export_names=True,
        )
        prepared_ds = _standardize_dataset(prepared_ds, self.config)
        super().__init__(basin_struct, prepared_ds, schema=self.schema)

        # Populated lazily by ``stations_table`` / ``interpolation``.
        self.table: pd.DataFrame | None = None
        self.lon_utm: np.ndarray | None = None
        self.lat_utm: np.ndarray | None = None
        self.interpolated: xr.Dataset | None = None

        # The CE is treated as a square: its side is the square root of the
        # area. The 1e6 factor presumably converts km^2 to m^2 -- confirm.
        ce_area = float(self.basin_struct.bassinVersant["superficieCE"]) * 1e6
        self.basin_struct.set_dimensions(np.sqrt(ce_area), np.sqrt(ce_area))

    @classmethod
    def from_dataset(
        cls,
        basin_struct: Basin,
        ds: xr.Dataset,
        config: NetCDFGridConfig | None = None,
        schema: MeteoSchema | None = None,
    ) -> "NetCDFMeteo":
        """Build a :class:`NetCDFMeteo` instance from an in-memory dataset."""
        return cls(basin_struct, ds, config=config, schema=schema)

    @classmethod
    def load_from_netcdf(
        cls,
        basin_struct: Basin,
        vars_path: str,
        config: NetCDFGridConfig | None = None,
        schema: MeteoSchema | None = None,
    ) -> "NetCDFMeteo":
        """Load prepared meteorological NetCDF files from a folder.

        Args:
            basin_struct: Basin description used by the workflow.
            vars_path: Folder containing the ``.nc`` files to merge.
            config: Optional grid configuration.
            schema: Optional variable schema.

        Returns:
            A workflow instance wrapping the merged dataset.
        """
        meteo_schema = schema or DEFAULT_METEO_SCHEMA
        ds = cls.load_netcdf_dataset(vars_path, schema=meteo_schema, export_names=True)
        return cls(basin_struct, ds, config=config, schema=meteo_schema)

    @classmethod
    def prepare_dataset(
        cls,
        ds: xr.Dataset,
        *,
        schema: MeteoSchema | None = None,
        file_label: str | None = None,
        export_names: bool = False,
    ) -> xr.Dataset:
        """Normalize a meteorological dataset to the internal NetCDFMeteo contract.

        Coordinates are standardized, every data variable is converted to
        its canonical units and renamed to its canonical name, and the time
        axis is validated as daily.

        Args:
            ds: Dataset to normalize.
            schema: Variable schema; defaults to ``DEFAULT_METEO_SCHEMA``.
            file_label: Optional file name included in validation errors.
            export_names: When true, rename the canonical variables to their
                CEQUEAU-facing names before returning.

        Returns:
            The normalized dataset.

        Raises:
            ValueError: If two variables normalize to the same canonical
                name, or if coordinate/time-axis validation fails.
        """
        meteo_schema = schema or DEFAULT_METEO_SCHEMA
        ds = _standardize_input_dataset(ds)

        prepared_vars: dict[str, xr.DataArray] = {}
        for variable_name in ds.data_vars:
            spec = meteo_schema.get_variable_spec(variable_name)
            converted = UnitHandler.convert_dataarray_to_canonical_units(
                ds[variable_name],
                spec,
            )
            canonical_name = spec.canonical_name
            if canonical_name in prepared_vars:
                raise ValueError(
                    f"Dataset contains duplicate variables that normalize to '{canonical_name}'."
                )
            prepared_vars[canonical_name] = converted.rename(canonical_name)

        prepared = xr.Dataset(prepared_vars, coords=ds.coords, attrs=dict(ds.attrs))
        _validate_daily_time_axis(prepared, file_label=file_label)
        if export_names:
            prepared = _export_to_cequeau_names(prepared, meteo_schema)
        return prepared

    @classmethod
    def load_netcdf_dataset(
        cls,
        vars_path: str,
        *,
        schema: MeteoSchema | None = None,
        export_names: bool = True,
    ) -> xr.Dataset:
        """Load and merge all supported NetCDF meteorological files in a folder.

        Args:
            vars_path: Folder scanned (non-recursively) for ``.nc`` files.
            schema: Variable schema; defaults to ``DEFAULT_METEO_SCHEMA``.
            export_names: When true, rename the merged variables to their
                CEQUEAU-facing names.

        Returns:
            The merged dataset.

        Raises:
            ValueError: If no ``.nc`` file is found or the files do not
                share the same time coordinate.
        """
        meteo_schema = schema or DEFAULT_METEO_SCHEMA
        datasets: list[xr.Dataset] = []
        # Sorted so the merge order does not depend on directory listing order.
        for file_name in sorted(os.listdir(vars_path)):
            if not file_name.endswith(".nc"):
                continue
            file_path = os.path.join(vars_path, file_name)
            # NOTE(review): the opened datasets are not closed here; they
            # presumably stay lazily backed by the files -- confirm.
            ds = xr.open_dataset(file_path, engine="netcdf4")
            prepared = cls.prepare_dataset(
                ds,
                schema=meteo_schema,
                file_label=file_name,
                export_names=False,
            )
            datasets.append(prepared)

        if not datasets:
            raise ValueError(f"No NetCDF files were found in '{vars_path}'.")

        _validate_aligned_time_axes(datasets)
        merged = xr.merge(datasets, join="exact", compat="override")
        if export_names:
            merged = _export_to_cequeau_names(merged, meteo_schema)
        return merged

    def stations_table(
        self,
        name_meteo_grid_file: str = "meteo_grid_points.shp",
        export: bool = True,
    ) -> pd.DataFrame:
        """Build the station table used for raster interpolation.

        Args:
            name_meteo_grid_file: File name used when exporting the table
                under ``<project_path>/geographic``.
            export: When true, also write the table as a vector file.

        Returns:
            The station table, also stored on ``self.table``.
        """
        dem = gdal.Open(self.basin_struct.DEM, gdal.GA_ReadOnly)
        ce_grid = gdal.Open(self.basin_struct.get_CEgrid, gdal.GA_ReadOnly)
        watershed = ogr.Open(self.basin_struct.watershed_shapefile, gdal.GA_ReadOnly)
        epsg_dem = projections.get_proj_code(dem)

        xy_pair = _get_netcdf_grid_points(self.ds, ce_grid, watershed, self.config)
        # Column 0 is used as longitude and column 1 as latitude when the
        # table is built, hence the (1, 0) argument order here.
        lon_utm, lat_utm = projections.latlon_to_utm(
            xy_pair[:, 1],
            xy_pair[:, 0],
            epsg_dem,
        )
        self.lon_utm = lon_utm
        self.lat_utm = lat_utm
        self.table = _create_station_table(
            ce_grid,
            dem,
            lon_utm,
            lat_utm,
            xy_pair,
            self.config,
        )
        if export:
            self._export_station_table(name_meteo_grid_file, epsg_dem)
        return self.table

    def interpolation(
        self,
        method: str = "nearest",
        name_meteo_grid_file: str = "meteo_grid_points.shp",
    ) -> xr.Dataset:
        """Interpolate the input dataset over the CE grid.

        Args:
            method: Interpolation method; must be listed in ``__methods__``.
            name_meteo_grid_file: File name forwarded to
                :meth:`stations_table` when the table is not built yet.

        Returns:
            The interpolated dataset, also stored on ``self.interpolated``.

        Raises:
            ValueError: If ``method`` is not supported.
        """
        if method not in __methods__:
            raise ValueError(
                f"Unsupported interpolation method '{method}'. Expected one of {__methods__}."
            )
        # Build the station table on demand.
        if self.table is None:
            self.stations_table(name_meteo_grid_file=name_meteo_grid_file)
        self.interpolated = _interpolate_dataset_to_ce_grid(
            self.ds,
            self.basin_struct.get_CEgrid,
            self.table,
            method,
            self.config,
        )
        return self.interpolated

    def to_cequeau_grid(self, ds: xr.Dataset | None = None) -> xr.Dataset:
        """Export an interpolated dataset to the CEQUEAU meteorological layout.

        Args:
            ds: Dataset to convert; defaults to ``self.interpolated``.

        Returns:
            The dataset reshaped to the CEQUEAU layout.

        Raises:
            ValueError: If no dataset is given and none has been interpolated.
        """
        source = ds if ds is not None else self.interpolated
        if source is None:
            raise ValueError("No interpolated dataset is available to convert.")
        # NOTE(review): ``cequeau_grid`` is not defined in this class; it is
        # presumably provided by ``Meteo`` and dispatches to
        # ``_cequeau_grid`` -- confirm.
        return self.cequeau_grid(source, self.basin_struct)

    @classmethod
    def _cequeau_grid(cls, ds: xr.Dataset, basin_struct: Basin) -> xr.Dataset:
        """Convert an interpolated gridded dataset to the CEQUEAU grid layout."""
        ce_df = _load_ce_table(basin_struct)
        return _build_cequeau_grid(ds, ce_df)

    def _export_station_table(self, file_name: str, epsg_code: int) -> None:
        """Write the station table as a point layer under ``geographic``.

        Args:
            file_name: Output file name inside ``<project_path>/geographic``.
            epsg_code: CRS assigned to the exported geometries.

        Raises:
            ValueError: If the station table has not been created yet.
        """
        if self.table is None:
            raise ValueError("The station table must be created before exporting it.")
        points = [
            Point(x, y)
            for y, x in zip(self.table["lat_utm"], self.table["lon_utm"])
        ]
        gdf = gpd.GeoDataFrame(self.table.copy(), geometry=points, crs=epsg_code)
        output_path = os.path.join(
            self.basin_struct.project_path,
            "geographic",
            file_name,
        )
        gdf.to_file(output_path)
def _standardize_input_dataset(ds: xr.Dataset) -> xr.Dataset:
    """Return ``ds`` with ``time``/``lat``/``lon`` coordinates, sorted ascending.

    Long coordinate names (``longitude``, ``latitude``, ``valid_time``) are
    mapped to the short ones, non-gregorian calendars are passed through
    ``fix_calendar``, and a non-datetime64 time index is converted to a
    datetime index.

    Raises:
        ValueError: If ``time``, ``lat`` or ``lon`` is still missing after
            renaming.
    """
    renames: dict[str, str] = {}
    for long_name, short_name in (("longitude", "lon"), ("latitude", "lat")):
        if long_name in ds.coords and short_name not in ds.coords:
            renames[long_name] = short_name

    if "valid_time" in ds.dims and "time" not in ds.dims:
        renames["valid_time"] = "time"
    elif (
        "valid_time" in ds.coords
        and "time" not in ds.coords
        and ds["valid_time"].ndim == 1
    ):
        # ``valid_time`` is an auxiliary coordinate along another dimension.
        carrier_dim = ds["valid_time"].dims[0]
        if carrier_dim == "time":
            # The dimension already has the right name; promote the values.
            ds = ds.assign_coords(time=ds["valid_time"])
            ds = ds.drop_vars("valid_time")
        else:
            ds = ds.swap_dims({carrier_dim: "valid_time"})
            renames["valid_time"] = "time"

    if renames:
        ds = ds.rename(renames)

    absent = {"time", "lat", "lon"}.difference(ds.coords)
    if absent:
        listing = ", ".join(sorted(absent))
        raise ValueError(f"Missing NetCDF coordinates: {listing}")

    # Only non-datetime64 time axes can carry a non-standard calendar.
    if (
        "time" in ds.coords
        and hasattr(ds["time"], "dt")
        and not np.issubdtype(ds["time"].dtype, np.datetime64)
    ):
        if getattr(ds["time"].dt, "calendar", "gregorian") != "gregorian":
            ds = fix_calendar(ds)

    for axis in ("time", "lat", "lon"):
        ds = ds.sortby(axis)

    if not np.issubdtype(ds["time"].dtype, np.datetime64):
        converted_index = ds.indexes["time"].to_datetimeindex()
        ds["time"] = xr.DataArray(converted_index, dims=("time",))
    return ds
def _validate_daily_time_axis(ds: xr.Dataset, file_label: str | None = None) -> None:
"""Check that the dataset uses a strictly increasing daily time axis."""
time_values = ds["time"].values
if len(time_values) < 2:
return
diffs = np.diff(time_values).astype("timedelta64[D]")
if np.any(diffs <= np.timedelta64(0, "D")):
prefix = f" in '{file_label}'" if file_label else ""
raise ValueError(f"Time coordinate must be strictly increasing{prefix}.")
expected_step = np.timedelta64(1, "D")
if np.any(diffs != expected_step):
prefix = f" in '{file_label}'" if file_label else ""
raise ValueError(f"Expected a daily time axis with 1-day increments{prefix}.")
def _validate_aligned_time_axes(datasets: list[xr.Dataset]) -> None:
"""Check that all meteorological datasets share the same time coordinate."""
if not datasets:
return
reference = datasets[0]["time"].values
for dataset in datasets[1:]:
if not np.array_equal(reference, dataset["time"].values):
raise ValueError(
"Meteorological NetCDF files do not share the same time coordinate. "
"Preprocess them so all variables use the same daily dates."
)
def _export_to_cequeau_names(ds: xr.Dataset, schema: MeteoSchema) -> xr.Dataset:
"""Rename canonical internal variable names to their CEQUEAU-facing names."""
rename_map: dict[str, str] = {}
for variable_name in ds.data_vars:
export_name = schema.get_export_name(variable_name)
if variable_name != export_name:
rename_map[variable_name] = export_name
if rename_map:
ds = ds.rename(rename_map)
return ds
def _standardize_dataset(ds: xr.Dataset, config: NetCDFGridConfig) -> xr.Dataset:
"""Normalize coordinate names and ordering for interpolation operations."""
rename_map: dict[str, str] = {}
if "longitude" in ds.coords and config.lon_name not in ds.coords:
rename_map["longitude"] = config.lon_name
if "latitude" in ds.coords and config.lat_name not in ds.coords:
rename_map["latitude"] = config.lat_name
if rename_map:
ds = ds.rename(rename_map)
required_coords = {config.time_name, config.lat_name, config.lon_name}
missing = required_coords.difference(ds.coords)
if missing:
missing_names = ", ".join(sorted(missing))
raise ValueError(f"Missing NetCDF coordinates: {missing_names}")
ds = ds.sortby(config.time_name)
ds = ds.sortby(config.lat_name)
ds = ds.sortby(config.lon_name)
return ds
def _load_ce_table(basin_struct: Basin) -> pd.DataFrame:
"""Load the CE grid index table used to reshape interpolated outputs."""
ce_path = os.path.join(
basin_struct.project_path,
"results",
"carreauxEntiers.csv",
)
ce_df = pd.read_csv(ce_path, index_col=0)
ce_df.index = ce_df["CEid"].astype(int).values
ce_df["i"] = ce_df["i"].astype(int)
ce_df["j"] = ce_df["j"].astype(int)
return ce_df
def _build_cequeau_grid(ds: xr.Dataset, ce_df: pd.DataFrame) -> xr.Dataset:
    """Assemble the ``(pasTemp, CEid)`` dataset from interpolated fields.

    Every data variable except ``CE`` is reduced to one time series per row
    of ``ce_df``; variable attributes are carried over.

    Raises:
        ValueError: If ``ds`` has no data variable other than ``CE``.
    """
    meteo_names = [name for name in ds.data_vars if name != "CE"]
    if not meteo_names:
        raise ValueError("The interpolated dataset does not contain meteorological variables.")

    # Proleptic ordinal shifted by 366 days (presumably the MATLAB datenum
    # convention -- confirm against the CEQUEAU reader).
    timestamps = pd.to_datetime(ds["time"].values)
    datenum = np.array(
        [366.0 + stamp.toordinal() for stamp in timestamps],
        dtype=np.float32,
    )

    coordinates = {
        "CEid": ce_df["CEid"].values.astype(np.int32),
        "pasTemp": datenum,
    }
    result = xr.Dataset(coords=coordinates)
    for name in meteo_names:
        series = _extract_ce_timeseries(ds, ce_df, name)
        result[name] = (("pasTemp", "CEid"), series)
        result[name].attrs = ds[name].attrs
    return result
def _extract_ce_timeseries(
ds: xr.Dataset,
ce_df: pd.DataFrame,
variable_name: str,
) -> np.ndarray:
"""Extract one time series per CE cell for a given meteorological variable."""
ce_data = np.zeros((ds.sizes["time"], len(ce_df)), dtype=np.float32)
for position, (_, ce_row) in enumerate(ce_df.iterrows()):
i_index = int(ce_row["i"])
j_index = int(ce_row["j"])
ce_data[:, position] = ds[variable_name].loc[:, j_index, i_index].values.astype(
np.float32
)
return ce_data
def _interpolate_dataset_to_ce_grid(
    ds: xr.Dataset,
    ce_grid_path: str,
    table: pd.DataFrame,
    method: str,
    config: NetCDFGridConfig,
) -> xr.Dataset:
    """Interpolate a gridded meteorological dataset onto the CE raster grid.

    The dataset is cropped to the lat/lon range spanned by ``table``, its
    spatial coordinates are replaced by evenly spaced CE ``i``/``j`` index
    values, and the fields are interpolated at every CE index. Cells where
    the CE raster is not positive are masked out.

    Args:
        ds: Dataset with the coordinates named in ``config``.
        ce_grid_path: Path of the CE raster opened with GDAL.
        table: Station table providing ``lat`` and ``lon`` columns.
        method: Method passed to ``xarray.Dataset.interp``.
        config: Coordinate names and CE index offset.

    Returns:
        Dataset with dimensions ``(time, j, i)`` plus the ``CE`` layer.
    """
    ce_grid = gdal.Open(ce_grid_path, gdal.GA_ReadOnly)
    ce_array = np.array(ce_grid.GetRasterBand(1).ReadAsArray())
    row_count, col_count = ce_array.shape
    i = np.arange(col_count, dtype=np.int16) + config.ce_index_offset
    j = np.arange(row_count, dtype=np.int16) + config.ce_index_offset

    working_ds = ds.copy()
    if float(working_ds[config.lon_name].max()) > 180:
        # Wrap only the values above 180 degrees. Shifting every longitude
        # would move valid values (<= 180) of a 0..360 grid down to
        # -360..-180 and break the nearest-point lookup below.
        lon_values = working_ds[config.lon_name].values
        working_ds = working_ds.assign_coords(
            {config.lon_name: np.where(lon_values > 180, lon_values - 360, lon_values)}
        )
        working_ds = working_ds.sortby(config.lon_name)

    lon_max_idx = u.find_nearest(working_ds[config.lon_name].values, table["lon"].max())
    lon_min_idx = u.find_nearest(working_ds[config.lon_name].values, table["lon"].min())
    lat_max_idx = u.find_nearest(working_ds[config.lat_name].values, table["lat"].max())
    lat_min_idx = u.find_nearest(working_ds[config.lat_name].values, table["lat"].min())

    # NOTE(review): the slices are end-exclusive, so the grid point nearest
    # to the table maximum is dropped (and a single-point extent yields an
    # empty selection) -- confirm whether ``max + 1`` was intended.
    working_ds = working_ds.isel(
        {
            config.lat_name: slice(min(lat_min_idx, lat_max_idx), max(lat_min_idx, lat_max_idx)),
            config.lon_name: slice(min(lon_min_idx, lon_max_idx), max(lon_min_idx, lon_max_idx)),
        }
    )

    # Re-express the cropped grid in CE index space so it can be sampled at
    # integer CE indices.
    working_ds = working_ds.assign_coords(
        {
            config.lon_name: np.linspace(i.min(), i.max(), working_ds.sizes[config.lon_name]),
            config.lat_name: np.linspace(j.min(), j.max(), working_ds.sizes[config.lat_name]),
        }
    )

    # Sample ``j`` in descending order (raster rows are listed first to last).
    flipped_j = np.flip(j)
    interpolated = working_ds.interp(
        {
            config.time_name: working_ds[config.time_name],
            config.lat_name: flipped_j,
            config.lon_name: i,
        },
        method=method,
    )
    interpolated = interpolated.where(ce_array > 0)
    interpolated = interpolated.rename(
        {
            config.time_name: "time",
            config.lat_name: "j",
            config.lon_name: "i",
        }
    )
    interpolated = interpolated.transpose("time", "j", "i")
    interpolated["i"] = i.astype(np.int16)
    interpolated["j"] = flipped_j.astype(np.int16)
    interpolated = _append_ce_grid(interpolated, ce_grid)
    return interpolated.assign_attrs(
        interpolated=f"Interpolated using xarray.Dataset.interp with method='{method}'"
    )
def _get_netcdf_grid_points(
    ds: xr.Dataset,
    ce_grid: gdal.Dataset,
    watershed: ogr.DataSource,
    config: NetCDFGridConfig,
) -> np.ndarray:
    """Return the regular lat/lon points that fall in the watershed extent.

    Candidate points are generated over the CE raster extent using the
    spacing between the first two latitude and longitude values of ``ds``,
    then filtered with ``utils.falls_in_extent``.
    """
    grid_x, grid_y, _ = u.GetExtent(ce_grid)
    epsg_code = projections.get_proj_code(ce_grid)
    lon_bounds, lat_bounds = projections.utm_to_latlon(
        (np.amin(grid_x), np.amax(grid_x)),
        (np.amin(grid_y), np.amax(grid_y)),
        epsg_code,
    )

    # Grid spacing taken from the first two coordinate values.
    lat_axis = ds[config.lat_name]
    lon_axis = ds[config.lon_name]
    lat_step = abs(lat_axis[0].values - lat_axis[1].values)
    lon_step = abs(lon_axis[0].values - lon_axis[1].values)

    lon_mask = np.arange(lon_bounds[0], lon_bounds[1], lon_step)
    lat_mask = np.arange(lat_bounds[0], lat_bounds[1], lat_step)

    xmin, xmax, ymin, ymax = watershed.GetLayer().GetExtent()
    watershed_extent = projections.utm_to_latlon((xmin, xmax), (ymin, ymax), epsg_code)
    return u.falls_in_extent(tuple(watershed_extent), lon_mask, lat_mask)
def _create_station_table(
    ce_grid: gdal.Dataset,
    dem: gdal.Dataset,
    lon_utm: np.ndarray,
    lat_utm: np.ndarray,
    xy_pair: np.ndarray,
    config: NetCDFGridConfig,
) -> pd.DataFrame:
    """Build the interpolation support table for the NetCDF grid points.

    Points whose raster row/column falls outside the CE raster are dropped.

    Args:
        ce_grid: CE raster; cell values are stored in the ``CEid`` column.
        dem: Elevation raster sampled for the ``altitude`` column.
        lon_utm: Projected x coordinate of every point.
        lat_utm: Projected y coordinate of every point.
        xy_pair: Geographic coordinates; column 0 is stored as ``lon`` and
            column 1 as ``lat``.
        config: Provides the offset added to the ``i``/``j`` indices.

    Returns:
        One row per retained point with columns ``id``, ``i``, ``j``,
        ``lat``, ``lon``, ``lat_utm``, ``lon_utm``, ``CEid``, ``altitude``.
    """
    ce_array = ce_grid.ReadAsArray()
    row_count, col_count = ce_array.shape
    i = np.arange(col_count, dtype=int)
    # ``j`` counts rows from the last raster row to the first.
    j = np.flip(np.arange(row_count, dtype=int))

    row, col = u.get_index_list(ce_grid, lon_utm, lat_utm)
    row = np.array(row, dtype=np.float32)
    col = np.array(col, dtype=np.float32)

    # Keep only points inside the raster. The lower bound matters too:
    # negative indices would silently wrap around and pick cells from the
    # opposite edge of the grid.
    inside = (row >= 0) & (row < row_count) & (col >= 0) & (col < col_count)
    xy_pair = xy_pair[inside, :]
    lat_utm = lat_utm[inside]
    lon_utm = lon_utm[inside]
    row = row[inside].astype(np.int16)
    col = col[inside].astype(np.int16)

    return pd.DataFrame(
        data={
            "id": [f"NC-grid-{num}" for num in range(len(col))],
            "i": i[col] + config.ce_index_offset,
            "j": j[row] + config.ce_index_offset,
            "lat": xy_pair[:, 1],
            "lon": xy_pair[:, 0],
            "lat_utm": lat_utm,
            "lon_utm": lon_utm,
            "CEid": ce_array[row, col],
            "altitude": u.get_altitude_point(dem, lat_utm, lon_utm),
        }
    )
def _append_ce_grid(ds: xr.Dataset, ce_grid: gdal.Dataset) -> xr.Dataset:
    """Attach the CE raster index layer to an interpolated meteorological dataset.

    Args:
        ds: Dataset with ``j`` and ``i`` dimensions matching the raster shape.
        ce_grid: CE raster whose cell values become the ``CE`` variable;
            zero cells are replaced by NaN.

    Returns:
        ``ds`` with an additional ``CE`` variable on ``(j, i)``.
    """
    # float32 keeps integer ids exact up to 2**24; float16 is only exact up
    # to 2048 and would corrupt larger CE ids.
    grid = ce_grid.ReadAsArray().astype(np.float32)
    grid[grid == 0] = np.nan
    # NOTE(review): these attrs live on the temporary dataset and are not
    # carried over by ``assign`` below -- confirm whether they were meant
    # to be set on the ``CE`` variable itself.
    ce_dataset = xr.Dataset(
        {
            "CE": (
                ("j", "i"),
                grid,
            )
        },
        attrs={
            "units": "-",
            "long_name": "Whole squares",
            "name": "CE",
        },
    )
    return ds.assign(CE=ce_dataset["CE"])