# Source code for gcpy.plot.compare_zonal_mean

#!/usr/bin/env python3
"""
Creates a six-panel comparison plot of zonal means from two different
GEOS-Chem model versions.  Called from the GEOS-Chem benchmarking scripts
and from the compare_diags.py example script.
"""
import os
import gc
import copy
import warnings
from multiprocessing import current_process
from tempfile import TemporaryDirectory
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import xarray as xr
from joblib import Parallel, delayed
from pypdf import PdfReader, PdfWriter
from gcpy.grid import get_vert_grid, get_pressure_indices, \
    pad_pressure_edges, convert_lev_to_pres
from gcpy.regrid import regrid_comparison_data, create_regridders, gen_xmat, \
    regrid_vertical
from gcpy.util import \
    get_molwt_from_metadata, reshape_MAPL_CS, get_diff_of_diffs, \
    all_zero_or_nan, compare_varnames, \
    read_species_metadata, verify_variable_type
from gcpy.units import check_units, data_unit_is_mol_per_mol
from gcpy.constants import MW_AIR_g, NO_STRETCH_SG_PARAMS
from gcpy.plot.core import gcpy_style, six_panel_subplot_names, \
    _warning_format, WhGrYlRd
from gcpy.plot.six_plot import six_plot

# Silence numpy floating-point warnings for division by zero and invalid
# operations (e.g. 0/0), which would otherwise flood the output whenever
# ratios are taken of fields that contain zeros.
np.seterr(invalid="ignore", divide="ignore")

# Apply the GCPy matplotlib style sheet to every figure made in this module.
plt.style.use(gcpy_style)


def _get_stretched_grid_params(data):
    """
    Returns the stretched-grid parameters stored in a dataset's attributes.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset whose global attributes may describe a stretched grid.

    Returns
    -------
    sg_params : list
        [stretch factor, target longitude, target latitude].  These are
        read from either the lowercase (stretch_factor, target_longitude,
        target_latitude) or the uppercase (STRETCH_FACTOR, TARGET_LON,
        TARGET_LAT) attributes.  If neither set is present, returns
        NO_STRETCH_SG_PARAMS (stretch factor 1 means no stretch).
    """
    attrs = data.attrs
    if "stretch_factor" in attrs:
        return [
            attrs["stretch_factor"],
            attrs["target_longitude"],
            attrs["target_latitude"],
        ]
    if "STRETCH_FACTOR" in attrs:
        return [
            attrs["STRETCH_FACTOR"],
            attrs["TARGET_LON"],
            attrs["TARGET_LAT"],
        ]
    return NO_STRETCH_SG_PARAMS


def _select_time(data, varname, itime):
    """
    Returns data[varname] at time index itime.  If the variable has no
    time dimension (e.g. bpch data), it is returned unsliced.
    """
    dataarray = data[varname]
    if "time" in dataarray.dims:
        return dataarray.isel(time=itime)
    return dataarray


def _standardize_units(dataarray):
    """
    Normalizes the units of a DataArray in place for plotting.

    Data in a mol/mol-type unit is scaled by 1e9 and relabeled "ppb".
    Legacy (bpch-style) unit strings are renamed: "ppbv" -> "ppb",
    "W/m2" -> "W m-2", and "UNITLESS" -> "1".
    """
    if data_unit_is_mol_per_mol(dataarray):
        dataarray.values = dataarray.values * 1e9
        dataarray.attrs["units"] = "ppb"

    renames = {"ppbv": "ppb", "W/m2": "W m-2", "UNITLESS": "1"}
    units = dataarray.units.strip()
    if units in renames:
        dataarray.attrs["units"] = renames[units]


def _slice_air_density(met, itime, lev_ind):
    """
    Returns Met_AIRDEN from a meteorology dataset, sliced at level
    indices lev_ind and, if a time dimension exists, at time index itime.
    """
    airden = met["Met_AIRDEN"]
    if "time" in airden.dims:
        return airden.isel(time=itime, lev=lev_ind)
    return airden.isel(lev=lev_ind)


def _get_area(data, met, gridtype, label):
    """
    Returns the 2D grid-box area used for normalize_by_area=True.

    The area is taken from the AREA variable in data if present,
    otherwise from Met_AREAM2 in met.  It is sliced to the first time and
    level if those dimensions exist, and reshaped for cubed-sphere grids.

    Parameters
    ----------
    data : xarray.Dataset
        Main (Ref or Dev) dataset.
    met : xarray.Dataset or None
        Optional meteorology dataset.
    gridtype : str
        Grid type of data ("ll" or "cs").
    label : str
        "Ref" or "Dev", used in the error message.

    Raises
    ------
    ValueError
        If neither AREA nor Met_AREAM2 is available.
    """
    if "AREA" in data.data_vars:
        area = data["AREA"]
    elif met is not None and "Met_AREAM2" in met.data_vars:
        area = met["Met_AREAM2"]
    else:
        # Previously, a missing met dataset went unreported and caused a
        # NameError later on.  Raise a clear error instead.
        raise ValueError(
            "normalize_by_area = True but AREA not present in the "
            f"{label} dataset and {label.lower()} met with Met_AREAM2"
            " not passed!"
        )
    if "time" in area.dims:
        area = area.isel(time=0)
    if gridtype == "cs":
        area = reshape_MAPL_CS(area)
    # Make sure the area does not have a lev dimension
    if "lev" in area.dims:
        area = area.isel(lev=0)
    return area


def _area_normalization_applies(varname):
    """
    Returns False for diagnostics that are never normalized by area
    (WetLossConvFrac, Prod_*, Loss_*), True otherwise.
    """
    return not any(
        s in varname for s in ("WetLossConvFrac", "Prod_", "Loss_")
    )


def _temp_pdf_path(temp_dir, ivar):
    """
    Returns the path of the single-page PDF for variable index ivar.
    The file is placed directly inside temp_dir, regardless of the
    directory components of the final PDF name.
    """
    return os.path.join(str(temp_dir), f"zonal_mean_page_{ivar}.pdf")


def compare_zonal_mean(
        refdata,
        refstr,
        devdata,
        devstr,
        varlist=None,
        itime=0,
        refmet=None,
        devmet=None,
        weightsdir='.',
        pdfname="",
        cmpres=None,
        match_cbar=True,
        pres_range=None,
        normalize_by_area=False,
        enforce_units=True,
        convert_to_ugm3=False,
        spcdb_files=None,
        flip_ref=False,
        flip_dev=False,
        use_cmap_RdBu=False,
        verbose=False,
        log_color_scale=False,
        log_yaxis=False,
        extra_title_txt=None,
        n_job=-1,
        sigdiff_list=None,
        second_ref=None,
        second_dev=None,
        ref_vert_params=None,
        dev_vert_params=None,
        **extra_plot_args
):
    r"""
    Creates 3x2 comparison zonal-mean plots for variables common in two
    xarray Datasets.  Optionally save to PDF.

    Parameters
    ----------
    refdata : xarray.Dataset
        Dataset used as reference in comparison.
    refstr : str
        String description for reference data to be used in plots.
    devdata : xarray.Dataset
        Dataset used as development in comparison.
    devstr : str
        String description for development data to be used in plots.
    varlist : list of str, optional
        List of xarray dataset variable names to make plots for.
        Default value: None (will compare all common 3D variables)
    itime : int, optional
        Dataset time dimension index using 0-based system.
        Default value: 0
    refmet : xarray.Dataset, optional
        Dataset containing ref meteorology.
        Default value: None
    devmet : xarray.Dataset, optional
        Dataset containing dev meteorology.
        Default value: None
    weightsdir : str, optional
        Directory path for storing regridding weights.
        Default value: '.' (will create/store weights in current directory)
    pdfname : str, optional
        File path to save plots as PDF.
        Default value: Empty string (will not create PDF)
    cmpres : str, optional
        String description of grid resolution at which to compare
        datasets.
        Default value: None (will compare at highest resolution of
        Ref and Dev)
    match_cbar : bool, optional
        Set this flag to True to use the same colorbar bounds for both
        Ref and Dev plots.
        Default value: True
    pres_range : list of int, optional
        Pressure range of levels to plot [hPa].  The vertical axis will
        span the outer pressure edges of levels that contain pres_range
        endpoints.
        Default value: [0, 2000]
    normalize_by_area : bool, optional
        Set this flag to True to normalize raw data in both Ref and Dev
        datasets by grid area.  Input ref and dev datasets must include
        an AREA variable in m2 (or the met datasets must include
        Met_AREAM2) if normalizing by area.
        Default value: False
    enforce_units : bool, optional
        Set this flag to True to force an error if the variables in the
        Ref and Dev datasets have different units.
        Default value: True
    convert_to_ugm3 : bool, optional
        Whether to convert data units to ug/m3 for plotting.
        Default value: False
    spcdb_files : str or list, optional
        A single species_database.yml file or a list of files (e.g. for
        Ref & Dev).  Only used when convert_to_ugm3=True.
        Default value: None
    flip_ref : bool, optional
        Set this flag to True to flip the vertical dimension of 3D
        variables in the Ref dataset.
        Default value: False
    flip_dev : bool, optional
        Set this flag to True to flip the vertical dimension of 3D
        variables in the Dev dataset.
        Default value: False
    use_cmap_RdBu : bool, optional
        Set this flag to True to use a blue-white-red colormap for
        plotting raw reference and development datasets.
        Default value: False
    verbose : bool, optional
        Set this flag to True to enable informative printout.
        Default value: False
    log_color_scale : bool, optional
        Set this flag to True to enable plotting data (not diffs) on a
        log color scale.
        Default value: False
    log_yaxis : bool, optional
        Set this flag to True if you wish to create zonal mean plots
        with a log-pressure Y-axis.
        Default value: False
    extra_title_txt : str, optional
        Specifies extra text (e.g. a date string such as "Jan2016") for
        the top-of-plot title.
        Default value: None
    n_job : int, optional
        Defines the number of simultaneous workers for parallel
        plotting.  Set to 1 to disable parallel plotting.  Value of -1
        allows the application to decide.
        Default value: -1
    sigdiff_list : list of str, optional
        List to which the names of all quantities having significant
        differences are appended in place.  The criterion is
        abs(1 - max(fractional difference)) > 0.1, ignoring NaNs.
        Default value: None (results are then not returned)
    second_ref : xarray.Dataset, optional
        A dataset of the same model type / grid as refdata, to be used
        in diff-of-diffs plotting.
        Default value: None
    second_dev : xarray.Dataset, optional
        A dataset of the same model type / grid as devdata, to be used
        in diff-of-diffs plotting.
        Default value: None
    ref_vert_params : list of array-like, optional
        Hybrid grid parameter A in hPa and B (unitless).  Needed if ref
        grid is not 47 or 72 levels.
        Default value: None
    dev_vert_params : list of array-like, optional
        Hybrid grid parameter A in hPa and B (unitless).  Needed if dev
        grid is not 47 or 72 levels.
        Default value: None
    **extra_plot_args
        Any extra keyword arguments are passed through the plotting
        functions to be used in calls to pcolormesh() (CS) or imshow()
        (Lat/Lon).

    Raises
    ------
    ValueError
        If convert_to_ugm3=True without spcdb_files or met datasets, or
        combined with normalize_by_area, or with non-ppb data; or if
        normalize_by_area=True and no area variable can be found.
    """
    warnings.showwarning = _warning_format
    verify_variable_type(refdata, xr.Dataset)
    verify_variable_type(devdata, xr.Dataset)

    # Create empty lists for keyword arguments
    if sigdiff_list is None:
        sigdiff_list = []
    if ref_vert_params is None:
        ref_vert_params = [[], []]
    if dev_vert_params is None:
        dev_vert_params = [[], []]
    if pres_range is None:
        pres_range = [0, 2000]

    # Determine if doing diff-of-diffs
    diff_of_diffs = second_ref is not None and second_dev is not None

    # Prepare diff-of-diffs datasets if needed.
    # NOTE: If dates differ between datasets, a common (fake) time
    # coordinate may be needed here; this is not yet implemented.
    fracrefdata = None
    fracdevdata = None
    frac_refstr = ""
    frac_devstr = ""
    if diff_of_diffs:
        refdata, fracrefdata = get_diff_of_diffs(refdata, second_ref)
        devdata, fracdevdata = get_diff_of_diffs(devdata, second_dev)
        frac_refstr = 'GCC_dev / GCC_ref'
        frac_devstr = 'GCHP_dev / GCHP_ref'

    # If no varlist is passed, plot all 3D variables in the dataset
    if varlist is None:
        vardict = compare_varnames(refdata, devdata, quiet=not verbose)
        varlist = vardict["commonvars3D"]
        print("Plotting all 3D variables")
    n_var = len(varlist)

    # Exit out if there are no 3D variables
    if not n_var:
        print("WARNING: no 3D variables to plot zonal mean for!")
        return

    # If no PDF name passed, then do not save to PDF
    savepdf = pdfname != ""

    # If converting to ug/m3, read species database file(s) so that
    # we can obtain molecular weights.
    if convert_to_ugm3:
        if spcdb_files is None:
            raise ValueError(
                "You must pass 'spcdb_files' when convert_to_ugm3=True!"
            )
        ref_metadata, dev_metadata = read_species_metadata(
            spcdb_files,
            quiet=True
        )

    # Get mid-point pressure and edge pressures for each grid
    ref_pedge, ref_pmid, _ = get_vert_grid(refdata, *ref_vert_params)
    dev_pedge, dev_pmid, _ = get_vert_grid(devdata, *dev_vert_params)

    # Get indexes of pressure subrange (full range is default)
    ref_pedge_ind = get_pressure_indices(ref_pedge, pres_range)
    dev_pedge_ind = get_pressure_indices(dev_pedge, pres_range)

    # Pad edges if subset does not include surface or TOA so data spans
    # entire subrange
    ref_pedge_ind = pad_pressure_edges(
        ref_pedge_ind, refdata.sizes["lev"], np.size(ref_pmid))
    dev_pedge_ind = pad_pressure_edges(
        dev_pedge_ind, devdata.sizes["lev"], np.size(dev_pmid))

    # pmid indexes do not include last pedge index
    ref_pmid_ind = ref_pedge_ind[:-1]
    dev_pmid_ind = dev_pedge_ind[:-1]

    # Convert levels to pressures in ref and dev data
    refdata = convert_lev_to_pres(refdata, ref_pmid, ref_pedge)
    devdata = convert_lev_to_pres(devdata, dev_pmid, dev_pedge)
    if diff_of_diffs:
        fracrefdata = convert_lev_to_pres(fracrefdata, ref_pmid, ref_pedge)
        fracdevdata = convert_lev_to_pres(fracdevdata, dev_pmid, dev_pedge)

    # ==================================================================
    # Reduce pressure range if reduced range passed as input.  Indices
    # must be flipped if flipping vertical axis.
    # NOTE: this may require checking for 48 / 73 levels.
    # ==================================================================
    if flip_ref:
        ref_pmid_ind = refdata.sizes["lev"] - ref_pmid_ind[::-1] - 1
    if flip_dev:
        dev_pmid_ind = devdata.sizes["lev"] - dev_pmid_ind[::-1] - 1

    refdata = refdata.isel(lev=ref_pmid_ind)
    devdata = devdata.isel(lev=dev_pmid_ind)
    if diff_of_diffs:
        fracrefdata = fracrefdata.isel(lev=ref_pmid_ind)
        fracdevdata = fracdevdata.isel(lev=dev_pmid_ind)

    # Get stretched grid info, if any.  Parameter order is stretch
    # factor, target longitude, target latitude.
    sg_ref_params = _get_stretched_grid_params(refdata)
    sg_dev_params = _get_stretched_grid_params(devdata)

    [refres, refgridtype, devres, devgridtype, cmpres, cmpgridtype,
     regridref, regriddev, regridany, refgrid, devgrid, cmpgrid,
     refregridder, devregridder, refregridder_list, devregridder_list] = \
        create_regridders(
            refdata,
            devdata,
            weightsdir=weightsdir,
            cmpres=cmpres,
            zm=True,
            sg_ref_params=sg_ref_params,
            sg_dev_params=sg_dev_params
        )

    # Use smaller vertical grid as target for vertical regridding.
    # Convert target_index from numpy.int64 to int (as per Pylint).
    target_index = int(np.array([len(ref_pedge), len(dev_pedge)]).argmin())
    pedge = [ref_pedge, dev_pedge][target_index]
    pedge_ind = [ref_pedge_ind, dev_pedge_ind][target_index]

    # ==================================================================
    # Loop over all variables: slice, fix units, reshape, flip
    # ==================================================================
    ds_refs = [None] * n_var
    frac_ds_refs = [None] * n_var
    ds_devs = [None] * n_var
    frac_ds_devs = [None] * n_var

    for i, varname in enumerate(varlist):

        # Slice the data, allowing for no time dimension (bpch)
        ds_refs[i] = _select_time(refdata, varname, itime)
        ds_devs[i] = _select_time(devdata, varname, itime)
        if diff_of_diffs:
            frac_ds_refs[i] = _select_time(fracrefdata, varname, itime)
            frac_ds_devs[i] = _select_time(fracdevdata, varname, itime)

        # Convert mol/mol to ppb and rename legacy unit strings
        _standardize_units(ds_refs[i])
        _standardize_units(ds_devs[i])

        # The check_units function will throw an error if the units do
        # not match and enforce_units is True.
        check_units(ds_refs[i], ds_devs[i], enforce_units)

        # Convert from ppb to ug/m3 if convert_to_ugm3 is passed as true
        if convert_to_ugm3:

            # Error checks: must pass met, not normalize by area, and
            # be in ppb
            if refmet is None or devmet is None:
                raise ValueError(
                    "Met data must be passed to convert units to ug/m3."
                )
            if normalize_by_area:
                raise ValueError(
                    "Normalizing by area is not allowed if plotting ug/m3"
                )
            if ds_refs[i].units != "ppb" or ds_devs[i].units != "ppb":
                raise ValueError(
                    "Units must be mol/mol if converting to ug/m3."
                )

            # Slice air density data by time and lev
            # (assume same format and dimensions as refdata and devdata)
            ref_airden = _slice_air_density(refmet, itime, ref_pmid_ind)
            dev_airden = _slice_air_density(devmet, itime, dev_pmid_ind)

            # Species name is everything after the first "_"
            # (e.g. "SpeciesConc_O3" -> "O3")
            _, sep, rest = varname.partition("_")
            spc_name = rest if sep else varname
            ref_spc_mw_g = get_molwt_from_metadata(ref_metadata, spc_name)
            dev_spc_mw_g = get_molwt_from_metadata(dev_metadata, spc_name)

            if ref_spc_mw_g is None and dev_spc_mw_g is None:
                # Neither side can be converted: report it and leave the
                # data in ppb.  Do NOT skip the rest of the loop body, so
                # that the reshape/flip steps below are still applied.
                print(
                    f"Cannot convert {spc_name} to ug/m3! "
                    "no molecular weight found in Ref & Dev metadata!"
                )
            else:
                # If only one side has no molecular weight, warn but let
                # the comparison proceed (that side becomes NaN).
                if ref_spc_mw_g is None or dev_spc_mw_g is None:
                    print(
                        f"Cannot convert {spc_name} to ug/m3!, "
                        "no molecular weight was found!"
                    )

                # ug/m3 = 1e-9ppb * mol/g air * kg/m3 air * 1e3g/kg
                #         * g/mol spc * 1e6ug/g
                #       = ppb * air density * (spc MW / air MW)
                if ref_spc_mw_g is not None:
                    ds_refs[i].values *= \
                        ref_airden.values * (ref_spc_mw_g / MW_AIR_g)
                else:
                    ds_refs[i].values *= np.nan
                if dev_spc_mw_g is not None:
                    ds_devs[i].values *= \
                        dev_airden.values * (dev_spc_mw_g / MW_AIR_g)
                else:
                    ds_devs[i].values *= np.nan

                # Update units string (ug/m3 using mu)
                ds_refs[i].attrs["units"] = "\u03BCg/m3"
                ds_devs[i].attrs["units"] = "\u03BCg/m3"

        # Reshape cubed sphere data if using MAPL v1.0.0+
        # TODO: update function to expect data in this format
        ds_refs[i] = reshape_MAPL_CS(ds_refs[i])
        ds_devs[i] = reshape_MAPL_CS(ds_devs[i])
        if diff_of_diffs:
            frac_ds_refs[i] = reshape_MAPL_CS(frac_ds_refs[i])
            frac_ds_devs[i] = reshape_MAPL_CS(frac_ds_devs[i])

        # Flip in the vertical if applicable
        if flip_ref:
            ds_refs[i].data = ds_refs[i].data[::-1, :, :]
            if diff_of_diffs:
                frac_ds_refs[i].data = frac_ds_refs[i].data[::-1, :, :]
        if flip_dev:
            ds_devs[i].data = ds_devs[i].data[::-1, :, :]
            if diff_of_diffs:
                frac_ds_devs[i].data = frac_ds_devs[i].data[::-1, :, :]

    # ==================================================================
    # Get the area variables if normalize_by_area=True.  They can be
    # either in the main datasets as variable AREA or in the optionally
    # passed meteorology datasets as Met_AREAM2.
    # ==================================================================
    ref_area = None
    dev_area = None
    if normalize_by_area:
        ref_area = _get_area(refdata, refmet, refgridtype, "Ref")
        dev_area = _get_area(devdata, devmet, devgridtype, "Dev")

    # ==================================================================
    # Create arrays for each variable in the Ref and Dev dataset
    # and regrid to the comparison grid.
    # ==================================================================
    ds_ref_cmps = [None] * n_var
    ds_dev_cmps = [None] * n_var
    frac_ds_ref_cmps = [None] * n_var
    frac_ds_dev_cmps = [None] * n_var

    # Store units in case data changes from DataArray to numpy array
    ref_units = [None] * n_var
    dev_units = [None] * n_var

    # Regrid vertically if necessary (xmat is only used when the
    # corresponding vertical grid differs from the target grid)
    xmat = None
    if len(ref_pedge) != len(pedge):
        xmat = gen_xmat(ref_pedge[ref_pedge_ind], pedge[pedge_ind])
    elif len(dev_pedge) != len(pedge):
        xmat = gen_xmat(dev_pedge[dev_pedge_ind], pedge[pedge_ind])

    for i, varname in enumerate(varlist):
        ds_ref = ds_refs[i]
        ds_dev = ds_devs[i]
        frac_ds_ref = frac_ds_refs[i]
        frac_ds_dev = frac_ds_devs[i]

        # Do area normalization before regridding.  Use the CURRENT
        # variable's name (previously a stale name from an earlier loop
        # was tested here).
        if normalize_by_area and _area_normalization_applies(varname):
            ds_ref.values = ds_ref.values / ref_area.values
            ds_dev.values = ds_dev.values / dev_area.values
            if diff_of_diffs:
                frac_ds_ref.values = frac_ds_ref.values / ref_area.values
                frac_ds_dev.values = frac_ds_dev.values / dev_area.values

        # Save units for later use
        ref_units[i] = ds_ref.attrs["units"]
        dev_units[i] = ds_dev.attrs["units"]

        ref_nlev = len(ds_ref['lev'])
        dev_nlev = len(ds_dev['lev'])

        # Regrid variables horizontally
        # Ref
        ds_ref = regrid_comparison_data(
            ds_ref,
            refres,
            regridref,
            refregridder,
            refregridder_list,
            cmpgrid,
            refgridtype,
            cmpgridtype,
            nlev=ref_nlev
        )
        if diff_of_diffs:
            # Grid-type argument order must match the call above
            # (it was previously swapped).
            frac_ds_ref = regrid_comparison_data(
                frac_ds_ref,
                refres,
                regridref,
                refregridder,
                refregridder_list,
                cmpgrid,
                refgridtype,
                cmpgridtype,
                nlev=ref_nlev
            )
        # Dev
        ds_dev = regrid_comparison_data(
            ds_dev,
            devres,
            regriddev,
            devregridder,
            devregridder_list,
            cmpgrid,
            devgridtype,
            cmpgridtype,
            nlev=dev_nlev
        )
        if diff_of_diffs:
            frac_ds_dev = regrid_comparison_data(
                frac_ds_dev,
                devres,
                regriddev,
                devregridder,
                devregridder_list,
                cmpgrid,
                devgridtype,
                cmpgridtype,
                nlev=dev_nlev
            )

        # Store regridded CS data before dealing with vertical regridding
        if refgridtype == "cs":
            ds_refs[i] = ds_ref
            frac_ds_refs[i] = frac_ds_ref
        if devgridtype == "cs":
            ds_devs[i] = ds_dev
            frac_ds_devs[i] = frac_ds_dev

        # Reduce variables to smaller vert grid if necessary for comparison
        if len(ref_pedge) != len(pedge):
            ds_ref = regrid_vertical(ds_ref, xmat, dev_pmid[dev_pmid_ind])
            if diff_of_diffs:
                frac_ds_ref = regrid_vertical(
                    frac_ds_ref, xmat, dev_pmid[dev_pmid_ind])
        if len(dev_pedge) != len(pedge):
            ds_dev = regrid_vertical(ds_dev, xmat, ref_pmid[ref_pmid_ind])
            if diff_of_diffs:
                frac_ds_dev = regrid_vertical(
                    frac_ds_dev, xmat, ref_pmid[ref_pmid_ind])

        ds_ref_cmps[i] = ds_ref
        ds_dev_cmps[i] = ds_dev
        if diff_of_diffs:
            frac_ds_ref_cmps[i] = frac_ds_ref
            frac_ds_dev_cmps[i] = frac_ds_dev

    # Force garbage collection manually (frees memory)
    del refregridder, refregridder_list, devregridder, devregridder_list
    gc.collect()

    # Universal plot setup
    xtick_positions = np.arange(-90, 91, 30)
    xticklabels = [rf"{x}$\degree$" for x in xtick_positions]

    # ==================================================================
    # Define function to create a single page figure to be called
    # in a parallel loop.  Returns the variable name if it has a
    # significant difference, otherwise "".
    # ==================================================================
    def createfig(ivar, temp_dir=''):

        # Suppress harmless run-time warnings (mostly about underflow)
        warnings.filterwarnings('ignore', category=RuntimeWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        if savepdf and verbose:
            print(f"{ivar} ", end="")
        varname = varlist[ivar]

        # Assign data variables
        ds_ref = ds_refs[ivar]
        ds_dev = ds_devs[ivar]
        ds_ref_cmp = ds_ref_cmps[ivar]
        ds_dev_cmp = ds_dev_cmps[ivar]
        frac_ds_ref_cmp = frac_ds_ref_cmps[ivar]
        frac_ds_dev_cmp = frac_ds_dev_cmps[ivar]

        # ==============================================================
        # Set units and subtitle, including modification if normalizing
        # area.  Note if enforce_units is False (non-default) then
        # units on difference plots will be wrong.  Local names are used
        # so that the shared unit lists are not mutated.
        # ==============================================================
        cmn_units = ref_units[ivar]
        ref_unit = ref_units[ivar]
        dev_unit = dev_units[ivar]
        subtitle_extra = ""
        if normalize_by_area and _area_normalization_applies(varname):
            if "/" in cmn_units:
                cmn_units = f"{cmn_units}/m2"
            else:
                cmn_units = f"{cmn_units} m-2"
            ref_unit = cmn_units
            dev_unit = cmn_units
            subtitle_extra = ", Normalized by Area"

        # ==============================================================
        # Calculate zonal mean
        # ==============================================================
        if refgridtype == "ll":
            zm_ref = ds_ref.mean(dim="lon")
        else:
            zm_ref = ds_ref.mean(axis=2)
        if devgridtype == "ll":
            zm_dev = ds_dev.mean(dim="lon")
        else:
            zm_dev = ds_dev.mean(axis=2)

        # Comparison
        zm_dev_cmp = ds_dev_cmp.mean(axis=2)
        zm_ref_cmp = ds_ref_cmp.mean(axis=2)
        if diff_of_diffs:
            frac_zm_dev_cmp = frac_ds_dev_cmp.mean(axis=2)
            frac_zm_ref_cmp = frac_ds_ref_cmp.mean(axis=2)

        # ==============================================================
        # Get min and max values for use in the top-row plot colorbars
        # ==============================================================
        vmin_ref = float(zm_ref.min())
        vmax_ref = float(zm_ref.max())
        vmin_dev = float(zm_dev.min())
        vmax_dev = float(zm_dev.max())

        # Set vmin_both and vmax_both to use if match_cbar=True
        vmin_both = np.min([vmin_ref, vmin_dev])
        vmax_both = np.max([vmax_ref, vmax_dev])

        # ==============================================================
        # Test if Ref and/or Dev contain all zeroes or all NaNs.
        # This affects how min and max values of the color ranges
        # are set.
        # ==============================================================
        ref_values = \
            ds_ref.values if isinstance(ds_ref, xr.DataArray) else ds_ref
        dev_values = \
            ds_dev.values if isinstance(ds_dev, xr.DataArray) else ds_dev
        ref_is_all_zero, ref_is_all_nan = all_zero_or_nan(ref_values)
        dev_is_all_zero, dev_is_all_nan = all_zero_or_nan(dev_values)

        # ==============================================================
        # Calculate zonal mean difference
        # ==============================================================
        zm_diff = np.array(zm_dev_cmp) - np.array(zm_ref_cmp)
        absdiff_is_all_zero, absdiff_is_all_nan = all_zero_or_nan(zm_diff)

        # ==============================================================
        # Calculate fractional difference, set divides by zero to NaN
        # ==============================================================
        if diff_of_diffs:
            zm_fracdiff = np.array(frac_zm_dev_cmp) - \
                np.array(frac_zm_ref_cmp)
        else:
            zm_fracdiff = np.abs(np.array(zm_dev_cmp)) / \
                np.abs(np.array(zm_ref_cmp))
        zm_fracdiff = np.where(
            np.abs(zm_fracdiff) == np.inf, np.nan, zm_fracdiff
        )
        zm_fracdiff[zm_fracdiff > 1e308] = np.nan

        # Test if the frac. diff is zero everywhere or NaN everywhere
        fracdiff_is_all_zero = not np.any(zm_fracdiff) or \
            (np.nanmin(zm_fracdiff) == 0 and np.nanmax(zm_fracdiff) == 0)
        fracdiff_is_all_nan = np.isnan(zm_fracdiff).all()

        # ==============================================================
        # Create 3x2 figure
        # ==============================================================
        figs, ((ax0, ax1), (ax2, ax3), (ax4, ax5)) = plt.subplots(
            3, 2, figsize=[12, 15.3]
        )

        # Add extra padding so that plots don't bump into each other.
        # Zonal mean plots need extra padding at the left (for the
        # Y-axis label) and at the bottom (for the colorbar).
        plt.subplots_adjust(
            left=0.10,    # Fraction of page width, from left edge
            right=0.925,  # Fraction of page width, from left edge
            bottom=0.05,  # Fraction of page height, from bottom edge
            wspace=0.25,  # Horizontal spacing btw subplots (frac of width)
            hspace=0.35   # Vertical spacing btw subplots (frac of height)
        )

        # Give the plot a title
        offset = 0.96
        if extra_title_txt is not None:
            figs.suptitle(
                f"{varname}, Zonal Mean ({extra_title_txt})",
                y=offset,
            )
        else:
            figs.suptitle(
                f"{varname}, Zonal Mean",
                y=offset
            )

        # ==============================================================
        # Set color map objects.  Use gray for NaNs (zonal means are
        # always plotted on lat-alt grids).  Shallow copies prevent
        # set_bad() from modifying the base color tables.
        # ==============================================================
        if use_cmap_RdBu:
            cmap1 = copy.copy(mpl.colormaps["RdBu_r"])
        else:
            cmap1 = copy.copy(WhGrYlRd)
        cmap1.set_bad("gray")

        cmap_plot = copy.copy(mpl.colormaps["RdBu_r"])
        cmap_plot.set_bad(color="gray")

        # ==============================================================
        # Set titles for plots
        # ==============================================================
        if refgridtype == "ll":
            ref_title = f"{refstr} (Ref){subtitle_extra}\n{refres}"
        else:
            ref_title = \
                f"{refstr} (Ref){subtitle_extra}\n{cmpres} regridded from c{refres}"

        if devgridtype == "ll":
            dev_title = f"{devstr} (Dev){subtitle_extra}\n{devres}"
        else:
            dev_title = \
                f"{devstr} (Dev){subtitle_extra}\n{cmpres} regridded from c{devres}"

        if regridany:
            absdiff_dynam_title = \
                f"Difference ({cmpres})\nDev - Ref, Dynamic Range"
            absdiff_fixed_title = \
                f"Difference ({cmpres})\nDev - Ref, Restricted Range [5%,95%]"
            if diff_of_diffs:
                fracdiff_dynam_title = \
                    f"Difference ({cmpres}), " + \
                    f"Dynamic Range\n{frac_devstr} - {frac_refstr}"
                fracdiff_fixed_title = \
                    f"Difference ({cmpres}), " + \
                    f"Restricted Range [5%,95%]\n{frac_devstr} - {frac_refstr}"
            else:
                fracdiff_dynam_title = \
                    f"Ratio ({cmpres})\nDev/Ref, Dynamic Range"
                fracdiff_fixed_title = \
                    f"Ratio ({cmpres})\nDev/Ref, Fixed Range"
        else:
            absdiff_dynam_title = "Difference\nDev - Ref, Dynamic Range"
            absdiff_fixed_title = \
                "Difference\nDev - Ref, Restricted Range [5%,95%]"
            if diff_of_diffs:
                fracdiff_dynam_title = \
                    f"Difference, Dynamic Range\n{frac_devstr} - {frac_refstr}"
                fracdiff_fixed_title = \
                    "Difference, Restricted Range " + \
                    f"[5%,95%]\n{frac_devstr} - {frac_refstr}"
            else:
                fracdiff_dynam_title = "Ratio \nDev/Ref, Dynamic Range"
                fracdiff_fixed_title = "Ratio \nDev/Ref, Fixed Range"

        # ==============================================================
        # Bundle variables for the 6 plotting calls
        #   0 = Ref                 1 = Dev
        #   2 = Dynamic abs diff    3 = Restricted abs diff
        #   4 = Dynamic frac diff   5 = Restricted frac diff
        # ==============================================================
        subplots = six_panel_subplot_names(diff_of_diffs)
        all_zeros = [
            ref_is_all_zero,
            dev_is_all_zero,
            absdiff_is_all_zero,
            absdiff_is_all_zero,
            fracdiff_is_all_zero,
            fracdiff_is_all_zero,
        ]
        all_nans = [
            ref_is_all_nan,
            dev_is_all_nan,
            absdiff_is_all_nan,
            absdiff_is_all_nan,
            fracdiff_is_all_nan,
            fracdiff_is_all_nan,
        ]
        plot_vals = [zm_ref, zm_dev, zm_diff, zm_diff, zm_fracdiff, zm_fracdiff]
        axs = [ax0, ax1, ax2, ax3, ax4, ax5]
        cmaps = [cmap1, cmap1, cmap_plot, cmap_plot, cmap_plot, cmap_plot]
        rowcols = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
        titles = [
            ref_title,
            dev_title,
            absdiff_dynam_title,
            absdiff_fixed_title,
            fracdiff_dynam_title,
            fracdiff_fixed_title,
        ]
        grids = [refgrid, devgrid, cmpgrid, cmpgrid, cmpgrid, cmpgrid]
        if refgridtype != "ll":
            grids[0] = cmpgrid
        if devgridtype != "ll":
            grids[1] = cmpgrid
        extents = [None, None, None, None, None, None]
        masked = ["ZM", "ZM", "ZM", "ZM", "ZM", "ZM"]
        unit_list = [
            ref_unit,
            dev_unit,
            cmn_units,
            cmn_units,
            "unitless",
            "unitless",
        ]
        other_all_nans = [
            dev_is_all_nan,
            ref_is_all_nan,
            False,
            False,
            False,
            False,
        ]
        gridtypes = [
            cmpgridtype,
            cmpgridtype,
            cmpgridtype,
            cmpgridtype,
            cmpgridtype,
            cmpgridtype,
        ]
        pedges = [ref_pedge, dev_pedge, pedge, pedge, pedge, pedge]
        pedge_inds = [
            ref_pedge_ind,
            dev_pedge_ind,
            pedge_ind,
            pedge_ind,
            pedge_ind,
            pedge_ind,
        ]
        mins = [vmin_ref, vmin_dev, vmin_both]
        maxs = [vmax_ref, vmax_dev, vmax_both]
        ratio_logs = [False, False, False, False, True, True]

        # Plot
        for i in range(6):
            six_plot(
                subplots[i],
                all_zeros[i],
                all_nans[i],
                plot_vals[i],
                grids[i],
                axs[i],
                rowcols[i],
                titles[i],
                cmaps[i],
                unit_list[i],
                extents[i],
                masked[i],
                other_all_nans[i],
                gridtypes[i],
                mins,
                maxs,
                use_cmap_RdBu,
                match_cbar,
                verbose,
                log_color_scale,
                pedges[i],
                pedge_inds[i],
                log_yaxis,
                plot_type="zonal_mean",
                xtick_positions=xtick_positions,
                xticklabels=xticklabels,
                ratio_log=ratio_logs[i],
                **extra_plot_args
            )

        # ==============================================================
        # Write this page of 6-panel plots to its own temporary PDF.
        # Figures are left open when not saving to PDF so they can be
        # displayed interactively.
        # ==============================================================
        if savepdf:
            with PdfPages(_temp_pdf_path(temp_dir, ivar)) as pdf:
                pdf.savefig(figs)
            plt.close(figs)

        # ==============================================================
        # Report whether this variable has a significant difference.
        # Criterion: abs(1 - max(fracdiff)) > 0.1
        # NaNs are excluded because they mark places where fracdiff
        # could not be computed (div-by-zero).  The caller collects the
        # return values; appending here would not propagate out of
        # joblib worker processes and caused duplicates when serial.
        # ==============================================================
        if np.abs(1 - np.nanmax(zm_fracdiff)) > 0.1:
            return varname
        return ""

    # ==================================================================
    # Call figure generation function in a parallel loop over variables
    # ==================================================================

    # Disable parallelization if this routine is already being
    # called in parallel.  This is due to issues with matplotlib.
    if current_process().name != "MainProcess":
        n_job = 1

    if not savepdf:
        # Disable parallel plotting to allow interactive figure plotting
        results = [createfig(ivar) for ivar in range(n_var)]
    else:
        with TemporaryDirectory() as temp_dir:
            if n_job != 1:
                results = Parallel(n_jobs=n_job)(
                    delayed(createfig)(ivar, temp_dir)
                    for ivar in range(n_var)
                )
            else:
                results = [createfig(ivar, temp_dir) for ivar in range(n_var)]

            # Concatenate the individual page PDFs (in variable order)
            # while the temporary directory still exists
            writer = PdfWriter()
            for ivar in range(n_var):
                reader = PdfReader(_temp_pdf_path(temp_dir, ivar))
                for page in reader.pages:
                    writer.add_page(page)

            # Write combined PDF
            with open(pdfname, "wb") as ofile:
                writer.write(ofile)
            if verbose:
                print(f"Created {pdfname} for {n_var} variables")

    # Update the list of significant differences exactly once per
    # variable (empty strings mean "not significant")
    sigdiff_list.extend(name for name in results if name)

    warnings.showwarning = _warning_format