Skip to content

Commit

Permalink
Add 2d variance metrics to reservoir training (#2361)
Browse files Browse the repository at this point in the history
This adds scalar metrics for the mean vertically-summed grid-scale
variance of outputs in the x, y plane. Since the prognostic run
reservoir predictions have issues with too much grid-scale noise in
column-integrated quantities, I would like to see how the
hyperparameters affect this in offline evaluation. I don't have area or
pressure thicknesses saved in the data, so this is a very rough way of
estimating the variance in column-integrated quantities.

During synchronization of the reservoir the `_rc_out` precipitable water
field has a higher variance than the `_hyb_in` field, which suggests
that this should be visible in offline evaluation.


![image](https://github.com/ai2cm/fv3net/assets/16710132/a896a961-4324-4a51-a693-fec05b0e880a)
  • Loading branch information
AnnaKwa committed Nov 3, 2023
1 parent 91a580c commit 314edd5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
8 changes: 7 additions & 1 deletion external/fv3fit/fv3fit/reservoir/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
)
from .adapters import ReservoirDatasetAdapter, HybridReservoirDatasetAdapter
from .domain2 import RankXYDivider
from .validation import validation_prediction, log_rmse_z_plots, log_rmse_scalar_metrics
from .validation import (
validation_prediction,
log_rmse_z_plots,
log_rmse_scalar_metrics,
log_variance_scalar_metrics,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -285,6 +290,7 @@ def train_reservoir_model(
)
log_rmse_z_plots(ds_val, model.output_variables)
log_rmse_scalar_metrics(ds_val, model.output_variables)
log_variance_scalar_metrics(ds_val, model.output_variables)
except Exception as e:
logging.error("Error logging validation metrics to wandb", exc_info=e)
return adapter
Expand Down
55 changes: 51 additions & 4 deletions external/fv3fit/fv3fit/reservoir/validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy.ndimage import generic_filter
from typing import Union, Optional, Sequence
import xarray as xr
import tensorflow as tf
Expand All @@ -20,6 +21,28 @@
ReservoirAdapter = Union[HybridReservoirDatasetAdapter, ReservoirDatasetAdapter]


def _variance_2d(slice_2d):
"""Applies the standard deviation over a 2D slice."""
return generic_filter(slice_2d, np.var, size=(3, 3), mode="reflect")


def _compute_2d_variance_mean_zsum(arr):
    """
    Scalar proxy for the grid-scale spatial variance of column-integrated
    quantities, used when pressure thickness and area information are not
    available.

    ``_variance_2d`` is applied to every (x, y) slice of ``arr``; if the
    result has a "z" dimension it is summed over "z"; the mean over all
    remaining elements is returned via ``.item()``.
    """
    horizontal_dims = ["x", "y"]
    local_variance = xr.apply_ufunc(
        _variance_2d,
        arr,
        input_core_dims=[horizontal_dims],
        output_core_dims=[horizontal_dims],
        vectorize=True,
    )
    column_total = (
        local_variance.sum("z") if "z" in local_variance.dims else local_variance
    )
    return column_total.mean().item()


def _get_predictions_over_batch(
model: ReservoirModel,
states_with_overlap_time_series: Sequence[np.ndarray],
Expand Down Expand Up @@ -110,6 +133,8 @@ def validation_prediction(
one_step_predictions = np.array(one_step_prediction_time_series)[n_synchronize:-1]
time_means_to_calculate = {
"time_mean_prediction": one_step_predictions,
"time_mean_persistence": persistence,
"time_mean_target": target,
"time_mean_prediction_error": one_step_predictions - target,
"time_mean_persistence_error": persistence - target,
"time_mean_prediction_mse": (one_step_predictions - target) ** 2,
Expand All @@ -134,7 +159,6 @@ def validation_prediction(
diags_ = []
for key, data in time_means_to_calculate.items():
diags_.append(_time_mean_dataset(model.input_variables, data, key))

return xr.merge(diags_)


Expand All @@ -152,11 +176,14 @@ def log_rmse_z_plots(ds_val, variables):
# will need to change this.
for var in variables:
rmse = {}
for comparison in ["persistence", "prediction", "imperfect_prediction"]:
for comparison in [
"persistence",
"imperfect_prediction",
"prediction",
]:
mse_key = f"time_mean_{comparison}_mse_{var}"
if mse_key in ds_val:
rmse[comparison] = np.sqrt(ds_val[mse_key].mean(["x", "y"])).values

wandb.log(
{
f"val_rmse_zplot_{var}": wandb.plot.line_series(
Expand All @@ -170,6 +197,27 @@ def log_rmse_z_plots(ds_val, variables):
)


def log_variance_scalar_metrics(ds_val, variables):
    """Log 2D-variance scalar metrics for each variable to wandb.

    For every variable, the ``time_mean_target_{var}`` and
    ``time_mean_prediction_{var}`` entries of ``ds_val`` (when present) are
    reduced with ``_compute_2d_variance_mean_zsum`` and stored under
    ``time_mean_{comparison}_2d_variance_zsum_{var}``. When both are available
    the prediction/target ratio is stored under ``variance_ratio_{var}``.
    All collected values are sent in a single ``wandb.log`` call.

    Args:
        ds_val: dataset-like object supporting ``in`` and item lookup by key.
        variables: iterable of variable names.
    """
    log_data = {}
    for var in variables:
        for comparison in ["target", "prediction"]:
            key = f"time_mean_{comparison}_{var}"
            if key in ds_val:
                variance_key = f"time_mean_{comparison}_2d_variance_zsum_{var}"
                log_data[variance_key] = _compute_2d_variance_mean_zsum(ds_val[key])
        try:
            log_data[f"variance_ratio_{var}"] = (
                log_data[f"time_mean_prediction_2d_variance_zsum_{var}"]
                / log_data[f"time_mean_target_2d_variance_zsum_{var}"]
            )
        except (KeyError, ZeroDivisionError):
            # The ratio is optional: skip it when either variance is missing,
            # or when the target variance is exactly zero (the variances are
            # Python floats, so that division raises). Skipping keeps the
            # remaining metrics for all variables from being lost.
            pass
    wandb.log(log_data)


def log_rmse_scalar_metrics(ds_val, variables):
scaled_errors_persistence, scaled_errors_imperfect = [], []
for var in variables:
Expand Down Expand Up @@ -202,5 +250,4 @@ def log_rmse_scalar_metrics(ds_val, variables):
)
except (KeyError):
pass

wandb.log(log_data)

0 comments on commit 314edd5

Please sign in to comment.