# Source code for gcpy.plot.compare_single_level

#!/usr/bin/env python3
"""
compare_single_level.py: Function to create a six-panel plot comparing
quantities at a single model level for two different model versions.
Called from the GEOS-Chem benchmarking scripts and from the
compare_diags.py example script.
"""
import os
import gc
import copy
import warnings
from multiprocessing import current_process
from tempfile import TemporaryDirectory
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
from joblib import Parallel, delayed
from pypdf import PdfReader, PdfWriter
from gcpy.grid import get_grid_extents, call_make_grid
from gcpy.regrid import regrid_comparison_data, create_regridders
from gcpy.util import \
    reshape_MAPL_CS, get_diff_of_diffs, get_molwt_from_metadata, \
    all_zero_or_nan, slice_by_lev_and_time, compare_varnames, \
    read_species_metadata, verify_variable_type
from gcpy.units import check_units, data_unit_is_mol_per_mol
from gcpy.constants import MW_AIR_g, NO_STRETCH_SG_PARAMS
from gcpy.plot.core import gcpy_style, six_panel_subplot_names, \
    _warning_format, WhGrYlRd
from gcpy.plot.six_plot import six_plot

# Silence NumPy's floating-point warnings for divide-by-zero and invalid
# operations (e.g. 0/0): ratio plots trigger these routinely and the
# warnings would otherwise flood the output.
np.seterr(invalid="ignore", divide="ignore")

# Apply the GCPy matplotlib style sheet to every plot made by this module
plt.style.use(gcpy_style)


def _get_stretched_grid_params(dset):
    """
    Return the stretched-grid parameters stored in a dataset's attributes.

    Parameters
    ----------
    dset : xarray.Dataset
        Dataset whose global attributes are inspected.

    Returns
    -------
    list
        [stretch factor, target longitude, target latitude], read from
        either the lowercase (stretch_factor, target_longitude,
        target_latitude) or the uppercase (STRETCH_FACTOR, TARGET_LON,
        TARGET_LAT) attribute set.  NO_STRETCH_SG_PARAMS is returned
        when neither set is present.
    """
    if "stretch_factor" in dset.attrs:
        return [
            dset.attrs["stretch_factor"],
            dset.attrs["target_longitude"],
            dset.attrs["target_latitude"],
        ]
    if "STRETCH_FACTOR" in dset.attrs:
        return [
            dset.attrs["STRETCH_FACTOR"],
            dset.attrs["TARGET_LON"],
            dset.attrs["TARGET_LAT"],
        ]
    return NO_STRETCH_SG_PARAMS


def _get_area_for_normalization(dset, met, gridtype, label):
    """
    Return the grid-box area used when normalize_by_area=True.

    The area is taken from the AREA variable of ``dset`` if present,
    otherwise from the Met_AREAM2 variable of ``met``.  The first time
    index and the first lev index are selected if those dimensions exist,
    and cubed-sphere areas are reshaped with reshape_MAPL_CS.

    Parameters
    ----------
    dset : xarray.Dataset
        Main (Ref or Dev) dataset.
    met : xarray.Dataset or None
        Optional meteorology dataset for the same model.
    gridtype : str
        Grid type of ``dset`` ("ll" or "cs").
    label : str
        "Ref" or "Dev"; used only in the error message.

    Returns
    -------
    xarray.DataArray
        Area field with no time or lev dimension.

    Raises
    ------
    ValueError
        If neither AREA (in dset) nor Met_AREAM2 (in met) is available.
    """
    if "AREA" in dset.data_vars.keys():
        area = dset["AREA"]
    elif met is not None and "Met_AREAM2" in met.data_vars.keys():
        area = met["Met_AREAM2"]
    else:
        # Previously a met dataset lacking Met_AREAM2 left the area
        # unbound (NameError later); now every missing case raises here.
        msg = "normalize_by_area = True but AREA not " \
            + f"present in the {label} dataset and {label.lower()} met " \
            + "with Met_AREAM2 not passed!"
        raise ValueError(msg)

    if "time" in area.dims:
        area = area.isel(time=0)
    if gridtype == 'cs':
        area = reshape_MAPL_CS(area)

    # Make sure the area does not have a lev dimension
    if "lev" in area.dims:
        area = area.isel(lev=0)
    return area


def compare_single_level(
        refdata,
        refstr,
        devdata,
        devstr,
        varlist=None,
        ilev=0,
        itime=0,
        refmet=None,
        devmet=None,
        weightsdir='.',
        pdfname="",
        cmpres=None,
        match_cbar=True,
        normalize_by_area=False,
        enforce_units=True,
        convert_to_ugm3=False,
        flip_ref=False,
        flip_dev=False,
        use_cmap_RdBu=False,
        verbose=False,
        log_color_scale=False,
        extra_title_txt=None,
        extent=None,
        n_job=-1,
        sigdiff_list=None,
        second_ref=None,
        second_dev=None,
        spcdb_files=None,
        ll_plot_func='imshow',
        **extra_plot_args
):
    r"""
    Create single-level 3x2 comparison map plots for variables common
    in two xarray Datasets. Optionally save to PDF.

    Parameters
    ----------
    refdata : xarray.Dataset
        Dataset used as reference in comparison.
    refstr : str
        String description for reference data to be used in plots.
    devdata : xarray.Dataset
        Dataset used as development in comparison.
    devstr : str
        String description for development data to be used in plots.
    varlist : list of str, optional
        List of xarray dataset variable names to make plots for.
        Default value: None (will compare all common variables)
    ilev : int, optional
        Dataset level dimension index using 0-based system.
        Indexing is ambiguous when plotting differing vertical grids.
        Default value: 0
    itime : int, optional
        Dataset time dimension index using 0-based system.
        Default value: 0
    refmet : xarray.Dataset, optional
        Dataset containing ref meteorology.
        Default value: None
    devmet : xarray.Dataset, optional
        Dataset containing dev meteorology.
        Default value: None
    weightsdir : str, optional
        Directory path for storing regridding weights.
        Default value: '.' (will create/store weights in current directory)
    pdfname : str, optional
        File path to save plots as PDF.
        Default value: Empty string (will not create PDF)
    cmpres : str, optional
        String description of grid resolution at which to compare datasets.
        Default value: None (will compare at highest resolution of ref and dev)
    match_cbar : bool, optional
        Set this flag to True if you wish to use the same colorbar
        bounds for the Ref and Dev plots.
        Default value: True
    normalize_by_area : bool, optional
        Set this flag to True if you wish to normalize the Ref and Dev
        raw data by grid area. Input ref and dev datasets must include
        AREA variable in m2 (or the met datasets must contain Met_AREAM2)
        if normalizing by area.
        Default value: False
    enforce_units : bool, optional
        Set this flag to True to force an error if Ref and Dev variables
        have different units.
        Default value: True
    convert_to_ugm3 : bool, optional
        Whether to convert data units to ug/m3 for plotting.
        Default value: False
    spcdb_files : str or list, optional
        A single species_database.yml file or a list of files (e.g. for
        Ref & Dev).  Only used when convert_to_ugm3=True.
        Default value: None
    flip_ref : bool, optional
        Set this flag to True to flip the vertical dimension of 3D
        variables in the Ref dataset.
        Default value: False
    flip_dev : bool, optional
        Set this flag to True to flip the vertical dimension of 3D
        variables in the Dev dataset.
        Default value: False
    use_cmap_RdBu : bool, optional
        Set this flag to True to use a blue-white-red colormap for
        plotting the raw data in both the Ref and Dev datasets.
        Default value: False
    verbose : bool, optional
        Set this flag to True to enable informative printout.
        Default value: False
    log_color_scale : bool, optional
        Set this flag to True to plot data (not diffs) on a log color scale.
        Default value: False
    extra_title_txt : str, optional
        Specifies extra text (e.g. a date string such as "Jan2016")
        for the top-of-plot title.
        Default value: None
    extent : list, optional
        Defines the extent of the region to be plotted in form
        [minlon, maxlon, minlat, maxlat].
        Default value: None (plots the extent of the input grids;
        internally represented as [-1000, -1000, -1000, -1000])
    n_job : int, optional
        Defines the number of simultaneous workers for parallel plotting.
        Set to 1 to disable parallel plotting. Value of -1 allows the
        application to decide.
        Default value: -1
    sigdiff_list : list of str, optional
        List to which the names of quantities having significant
        differences (where \|1 - max(Dev/Ref ratio)\| > 0.1) are appended.
        Default value: None
    second_ref : xarray.Dataset, optional
        A dataset of the same model type / grid as refdata,
        to be used in diff-of-diffs plotting.
        Default value: None
    second_dev : xarray.Dataset, optional
        A dataset of the same model type / grid as devdata,
        to be used in diff-of-diffs plotting.
        Default value: None
    ll_plot_func : str, optional
        Function to use for lat/lon single level plotting with possible
        values 'imshow' and 'pcolormesh'. imshow is much faster but is
        slightly displaced when plotting from dateline to dateline and/or
        pole to pole.
        Default value: 'imshow'
    **extra_plot_args
        Any extra keyword arguments are passed through the plotting
        functions to be used in calls to pcolormesh() (CS) or imshow()
        (Lat/Lon).

    Raises
    ------
    ValueError
        If convert_to_ugm3=True without spcdb_files or met data, if
        convert_to_ugm3 and normalize_by_area are both requested, if
        units are not ppb when converting to ug/m3, or if area data is
        unavailable when normalize_by_area=True.
    """
    warnings.showwarning = _warning_format

    # Error check arguments
    verify_variable_type(refdata, xr.Dataset)
    verify_variable_type(devdata, xr.Dataset)

    # Create empty lists for keyword arguments
    if extent is None:
        extent = [-1000, -1000, -1000, -1000]
    if sigdiff_list is None:
        sigdiff_list = []

    # Variables whose names contain any of these substrings are never
    # normalized by area.
    norm_exclude_list = ["WetLossConvFrac", "Prod_", "Loss_"]

    # Determine if doing diff-of-diffs
    diff_of_diffs = second_ref is not None and second_dev is not None

    # Prepare diff-of-diffs datasets if needed
    if diff_of_diffs:
        refdata, fracrefdata = get_diff_of_diffs(refdata, second_ref)
        devdata, fracdevdata = get_diff_of_diffs(devdata, second_dev)
        frac_refstr = 'GCC_dev / GCC_ref'
        frac_devstr = 'GCHP_dev / GCHP_ref'

    # If no varlist is passed, plot all (surface only for 3D)
    if varlist is None:
        quiet = not verbose
        vardict = compare_varnames(refdata, devdata, quiet=quiet)
        varlist = vardict["commonvars3D"] + vardict["commonvars2D"]
        print("Plotting all common variables")
    n_var = len(varlist)

    # If no PDF name passed, then do not save to PDF
    savepdf = pdfname != ""

    # If converting to ug/m3, read species database file(s) to obtain
    # molecular weights.
    if convert_to_ugm3:
        if spcdb_files is None:
            msg = "You must pass 'spcdb_files' when convert_to_ugm3=True!"
            raise ValueError(msg)
        ref_metadata, dev_metadata = read_species_metadata(
            spcdb_files,
            quiet=True
        )

    # Get stretched grid info, if any.
    # Parameter order is stretch factor, target longitude, target latitude.
    # Stretch factor 1 corresponds with no stretch.
    sg_ref_params = _get_stretched_grid_params(refdata)
    sg_dev_params = _get_stretched_grid_params(devdata)

    # Get grid info and regrid if necessary
    [refres, refgridtype, devres, devgridtype, cmpres, cmpgridtype,
     regridref, regriddev, regridany, refgrid, devgrid, cmpgrid,
     refregridder, devregridder, refregridder_list,
     devregridder_list] = create_regridders(
         refdata,
         devdata,
         weightsdir,
         cmpres=cmpres,
         sg_ref_params=sg_ref_params,
         sg_dev_params=sg_dev_params
     )

    # ==============================================================
    # Handle grid extents for lat-lon grids
    # ==============================================================

    # Get lat/lon extents, if applicable
    refminlon, refmaxlon, refminlat, refmaxlat = get_grid_extents(refgrid)
    devminlon, devmaxlon, devminlat, devmaxlat = get_grid_extents(devgrid)

    if -1000 not in extent:
        cmpminlon, cmpmaxlon, cmpminlat, cmpmaxlat = extent
    else:
        # Account for 0-360 coordinate scale
        uniform_refminlon, uniform_refmaxlon = refminlon, refmaxlon
        uniform_devminlon, uniform_devmaxlon = devminlon, devmaxlon
        if uniform_refmaxlon > 185:
            uniform_refminlon, uniform_refmaxlon = -180, 180
        if uniform_devmaxlon > 185:
            uniform_devminlon, uniform_devmaxlon = -180, 180

        # NOTE: the previous expression (uniform_refminlon+180%360)-180
        # evaluated to uniform_refminlon because % binds tighter than +;
        # it is written here in its effective form.
        cmpminlon = np.max([uniform_refminlon, uniform_devminlon])
        cmpmaxlon = np.min([uniform_refmaxlon, uniform_devmaxlon])
        cmpminlat = np.max([refminlat, devminlat])
        cmpmaxlat = np.min([refmaxlat, devmaxlat])

    # Set plot bounds for non cubed-sphere regridding and plotting
    cmp_extent = (cmpminlon, cmpmaxlon, cmpminlat, cmpmaxlat)

    # ==============================================================
    # Loop over all variables: slice data and handle units
    # ==============================================================
    ds_refs = [None] * n_var
    frac_ds_refs = [None] * n_var
    ds_devs = [None] * n_var
    frac_ds_devs = [None] * n_var
    for i, varname in enumerate(varlist):

        # Slice the data, allowing for no time dimension (bpch)
        ds_refs[i] = slice_by_lev_and_time(
            refdata, varname, itime, ilev, flip_ref
        )
        if diff_of_diffs:
            frac_ds_refs[i] = slice_by_lev_and_time(
                fracrefdata, varname, itime, ilev, flip_ref
            )
        ds_devs[i] = slice_by_lev_and_time(
            devdata, varname, itime, ilev, flip_dev
        )
        if diff_of_diffs:
            frac_ds_devs[i] = slice_by_lev_and_time(
                fracdevdata, varname, itime, ilev, flip_dev
            )

        # Convert to ppb if units string is variation of mol/mol
        if data_unit_is_mol_per_mol(ds_refs[i]):
            ds_refs[i].values = ds_refs[i].values * 1e9
            ds_refs[i].attrs["units"] = "ppb"
        if data_unit_is_mol_per_mol(ds_devs[i]):
            ds_devs[i].values = ds_devs[i].values * 1e9
            ds_devs[i].attrs["units"] = "ppb"

        # If units string is ppbv (true for bpch data) then rename units
        if ds_refs[i].units.strip() == "ppbv":
            ds_refs[i].attrs["units"] = "ppb"
        if ds_devs[i].units.strip() == "ppbv":
            ds_devs[i].attrs["units"] = "ppb"

        # If units string is W/m2 (may be true for bpch data) then rename
        if ds_refs[i].units.strip() == "W/m2":
            ds_refs[i].attrs["units"] = "W m-2"
        if ds_devs[i].units.strip() == "W/m2":
            ds_devs[i].attrs["units"] = "W m-2"

        # If units string is UNITLESS (may be true for bpch data) then
        # rename units
        if ds_refs[i].units.strip() == "UNITLESS":
            ds_refs[i].attrs["units"] = "1"
        if ds_devs[i].units.strip() == "UNITLESS":
            ds_devs[i].attrs["units"] = "1"

        # Compare units of ref and dev. The check_units function will
        # throw an error if units do not match and enforce_units is True.
        check_units(ds_refs[i], ds_devs[i], enforce_units)

        # Convert from ppb to ug/m3 if convert_to_ugm3 is passed as true
        if convert_to_ugm3:

            # Error checks: must pass met, not normalize by area, and be in ppb
            if refmet is None or devmet is None:
                msg = "Met data must be passed to convert units to ug/m3."
                raise ValueError(msg)
            if normalize_by_area:
                msg = "Normalizing by area is not allowed if plotting ug/m3"
                raise ValueError(msg)
            if ds_refs[i].units != "ppb" or ds_devs[i].units != "ppb":
                msg = "Units must be mol/mol if converting to ug/m3."
                raise ValueError(msg)

            # Slice air density data by lev and time
            # (assume same format and dimensions as refdata and devdata)
            ref_airden = slice_by_lev_and_time(
                refmet, "Met_AIRDEN", itime, ilev, False
            )
            dev_airden = slice_by_lev_and_time(
                devmet, "Met_AIRDEN", itime, ilev, False
            )

            # Get the species molecular weights from Ref & Dev metadata
            spc_name = varname.replace(varname.split("_")[0] + "_", "")
            ref_spc_mw_g = get_molwt_from_metadata(ref_metadata, spc_name)
            dev_spc_mw_g = get_molwt_from_metadata(dev_metadata, spc_name)

            # Skip the conversion if the species has no molecular weight
            # in both Ref & Dev species metadata (report why)
            if ref_spc_mw_g is None and dev_spc_mw_g is None:
                msg = f"Cannot convert {spc_name} to ug/m3! "
                msg += "no molecular weight found in Ref & Dev metadata!"
                print(msg)
                continue

            # If only one of the species has no molecular weight
            # print a warning message but allow comparison to proceed
            if ref_spc_mw_g is None or dev_spc_mw_g is None:
                msg = f"Cannot convert {spc_name} to ug/m3!, "
                msg += "no molecular weight was found!"
                print(msg)

            # Convert values from ppb to ug/m3:
            # ug/m3 = 1e-9ppb * mol/g air * kg/m3 air * 1e3g/kg
            #         * g/mol spc * 1e6ug/g
            #       = ppb * air density * (spc MW / air MW)
            #
            # If mol. wt. is missing, then set data to NaN
            if ref_spc_mw_g is not None:
                ds_refs[i].values *= \
                    ref_airden.values * (ref_spc_mw_g / MW_AIR_g)
            else:
                ds_refs[i].values *= np.nan
            if dev_spc_mw_g is not None:
                ds_devs[i].values *= \
                    dev_airden.values * (dev_spc_mw_g / MW_AIR_g)
            else:
                ds_devs[i].values *= np.nan

            # Update units string
            ds_refs[i].attrs["units"] = "\u03BCg/m3"  # ug/m3 using mu
            ds_devs[i].attrs["units"] = "\u03BCg/m3"

    # ==================================================================
    # Get the area variables if normalize_by_area=True. They can be
    # either in the main datasets as variable AREA or in the optionally
    # passed meteorology datasets as Met_AREAM2.  The area does not
    # depend on the variable, so it is fetched once.
    # ==================================================================
    ref_area = None
    dev_area = None
    if normalize_by_area and n_var > 0:
        ref_area = _get_area_for_normalization(
            refdata, refmet, refgridtype, "Ref"
        )
        dev_area = _get_area_for_normalization(
            devdata, devmet, devgridtype, "Dev"
        )

    # ==============================================================
    # Reshape cubed sphere data if using MAPL v1.0.0+
    # TODO: update function to expect data in this format
    # ==============================================================
    for i in range(n_var):
        ds_refs[i] = reshape_MAPL_CS(ds_refs[i])
        ds_devs[i] = reshape_MAPL_CS(ds_devs[i])
        if diff_of_diffs:
            frac_ds_refs[i] = reshape_MAPL_CS(frac_ds_refs[i])
            frac_ds_devs[i] = reshape_MAPL_CS(frac_ds_devs[i])

    # ==================================================================
    # Create arrays for each variable in Ref and Dev datasets
    # and do any necessary horizontal regridding. 'cmp' stands for
    # comparison and represents ref and dev data regridded as needed to
    # a common grid type and resolution for use in difference and ratio
    # plots.
    # ==================================================================
    ds_ref_cmps = [None] * n_var
    ds_dev_cmps = [None] * n_var
    frac_ds_ref_cmps = [None] * n_var
    frac_ds_dev_cmps = [None] * n_var

    global_cmp_grid = call_make_grid(cmpres, cmpgridtype)[0]

    # Get grid limited to cmp_extent for comparison datasets
    # Do not do this for cross-dateline plotting
    if cmp_extent[0] < cmp_extent[1]:
        regional_cmp_extent = cmp_extent
    else:
        regional_cmp_extent = [-180, 180, -90, 90]

    regional_cmp_grid = call_make_grid(
        cmpres,
        cmpgridtype,
        in_extent=[-180, 180, -90, 90],
        out_extent=regional_cmp_extent
    )[0]

    # Get comparison data extents in same midpoint format as lat-lon grid.
    cmp_mid_minlon, cmp_mid_maxlon, cmp_mid_minlat, cmp_mid_maxlat = \
        get_grid_extents(regional_cmp_grid, edges=False)

    cmpminlon_ind = np.where(global_cmp_grid["lon"] >= cmp_mid_minlon)[0][0]
    cmpmaxlon_ind = np.where(global_cmp_grid["lon"] <= cmp_mid_maxlon)[0][-1]
    cmpminlat_ind = np.where(global_cmp_grid["lat"] >= cmp_mid_minlat)[0][0]
    cmpmaxlat_ind = np.where(global_cmp_grid["lat"] <= cmp_mid_maxlat)[0][-1]

    for i, varname in enumerate(varlist):
        ds_ref = ds_refs[i]
        ds_dev = ds_devs[i]

        # Do area normalization before regridding if normalize_by_area is
        # True. Assumes units are the same in ref and dev. If enforce_units
        # is passed as false then normalization may not be correct.
        # (Uses the current variable's name; previously a stale name from
        # the earlier loop was tested.)
        if normalize_by_area:
            if not any(s in varname for s in norm_exclude_list):
                ds_ref.values = ds_ref.values / ref_area.values
                ds_dev.values = ds_dev.values / dev_area.values
                ds_refs[i] = ds_ref
                ds_devs[i] = ds_dev
                if diff_of_diffs:
                    frac_ds_refs[i] = \
                        frac_ds_refs[i].values / ref_area.values
                    frac_ds_devs[i] = \
                        frac_ds_devs[i].values / dev_area.values

        ref_cs_res = refres
        dev_cs_res = devres
        if cmpgridtype == "cs":
            ref_cs_res = cmpres
            dev_cs_res = cmpres

        # Ref
        ds_ref_cmps[i] = regrid_comparison_data(
            ds_ref, ref_cs_res, regridref, refregridder, refregridder_list,
            global_cmp_grid, refgridtype, cmpgridtype,
            cmpminlat_ind, cmpmaxlat_ind, cmpminlon_ind, cmpmaxlon_ind
        )
        # Dev
        ds_dev_cmps[i] = regrid_comparison_data(
            ds_dev, dev_cs_res, regriddev, devregridder, devregridder_list,
            global_cmp_grid, devgridtype, cmpgridtype,
            cmpminlat_ind, cmpmaxlat_ind, cmpminlon_ind, cmpmaxlon_ind
        )
        # Diff of diffs
        if diff_of_diffs:
            frac_ds_ref_cmps[i] = regrid_comparison_data(
                frac_ds_refs[i], ref_cs_res, regridref, refregridder,
                refregridder_list, global_cmp_grid, refgridtype,
                cmpgridtype, cmpminlat_ind, cmpmaxlat_ind,
                cmpminlon_ind, cmpmaxlon_ind
            )
            frac_ds_dev_cmps[i] = regrid_comparison_data(
                frac_ds_devs[i], dev_cs_res, regriddev, devregridder,
                devregridder_list, global_cmp_grid, devgridtype,
                cmpgridtype, cmpminlat_ind, cmpmaxlat_ind,
                cmpminlon_ind, cmpmaxlon_ind
            )

    # Force garbage collection manually (frees memory)
    del refregridder, refregridder_list, devregridder, devregridder_list
    gc.collect()

    # =================================================================
    # Define function to create a single page figure to be called
    # in a parallel loop
    # =================================================================
    def createfig(ivar, temp_dir=''):
        """
        Build the 6-panel figure for variable ``varlist[ivar]``.

        If saving to PDF, the page is written to a per-variable file
        under ``temp_dir``.  Returns the variable name if it shows a
        significant difference (|1 - nanmax(ratio)| > 0.1), else "".
        """
        # Suppress harmless run-time warnings (mostly about underflow)
        warnings.filterwarnings('ignore', category=RuntimeWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        if savepdf and verbose:
            print(f"{ivar} ", end="")
        varname = varlist[ivar]

        ds_ref = ds_refs[ivar]
        ds_dev = ds_devs[ivar]

        # ==============================================================
        # Set units and subtitle, including modification if normalizing
        # area. Note if enforce_units is False (non-default) then
        # units on difference plots will be wrong.
        # ==============================================================
        cmn_units = ds_ref.attrs["units"]
        subtitle_extra = ""
        if normalize_by_area:
            if not any(s in varname for s in norm_exclude_list):
                if "/" in cmn_units:
                    cmn_units = f"{cmn_units}/m2"
                else:
                    cmn_units = f"{cmn_units} m-2"
                ds_ref.attrs["units"] = cmn_units
                ds_dev.attrs["units"] = cmn_units
                subtitle_extra = ", Normalized by Area"

        # ==============================================================
        # Get comparison data sets, regridding input slices if needed
        # ==============================================================

        # Initialize objects to avoid Pylint warnings
        ds_ref_cmp_reshaped = xr.Dataset()
        ds_dev_cmp_reshaped = xr.Dataset()
        frac_ds_ref_cmp_reshaped = xr.Dataset()
        frac_ds_dev_cmp_reshaped = xr.Dataset()

        # Reshape ref/dev cubed sphere data, if any
        ds_ref_reshaped = None
        if refgridtype == "cs":
            ds_ref_reshaped = ds_ref.data.reshape(6, refres, refres)
        ds_dev_reshaped = None
        if devgridtype == "cs":
            ds_dev_reshaped = ds_dev.data.reshape(6, devres, devres)

        ds_ref_cmp = ds_ref_cmps[ivar]
        ds_dev_cmp = ds_dev_cmps[ivar]
        frac_ds_ref_cmp = frac_ds_ref_cmps[ivar]
        frac_ds_dev_cmp = frac_ds_dev_cmps[ivar]

        # Reshape comparison cubed sphere data, if any
        if cmpgridtype == "cs":
            def call_reshape(cmp_data):
                new_data = None
                if isinstance(cmp_data, xr.DataArray):
                    new_data = cmp_data.data.reshape(6, cmpres, cmpres)
                elif isinstance(cmp_data, np.ndarray):
                    new_data = cmp_data.reshape(6, cmpres, cmpres)
                return new_data

            ds_ref_cmp_reshaped = call_reshape(ds_ref_cmp)
            ds_dev_cmp_reshaped = call_reshape(ds_dev_cmp)
            frac_ds_ref_cmp_reshaped = call_reshape(frac_ds_ref_cmp)
            frac_ds_dev_cmp_reshaped = call_reshape(frac_ds_dev_cmp)

        # ==============================================================
        # Get min and max values for use in the top-row plot colorbars
        # and also flag if Ref and/or Dev are all zero or all NaN.
        # ==============================================================

        # Choose from values within plot extent
        if -1000 not in extent:
            min_max_extent = extent
        else:
            min_max_extent = cmp_extent

        # Find min and max lon
        min_max_minlon = np.min([min_max_extent[0], min_max_extent[1]])
        min_max_maxlon = np.max([min_max_extent[0], min_max_extent[1]])
        min_max_minlat = min_max_extent[2]
        min_max_maxlat = min_max_extent[3]

        def get_extent_for_colors(dset, minlon, maxlon, minlat, maxlat):
            ds_new = dset.copy()
            lat_var = 'lat'
            lon_var = 'lon'

            # Account for cubed-sphere data
            if 'lons' in ds_new.coords:
                lat_var = 'lats'
                lon_var = 'lons'

            # Data on a 0-360 longitude scale: shift the bounds to match
            if ds_new[lon_var].max() > 190:
                minlon = minlon % 360
                maxlon = maxlon % 360

                # account for global plot
                if minlon == maxlon and maxlon == 180:
                    minlon = 0
                    maxlon = 360

                # account for cross dateline
                if minlon > maxlon:
                    minlon, maxlon = maxlon, minlon

            # Add .compute() to force evaluation of ds_new[lon_var]
            # See https://github.com/geoschem/gcpy/issues/254
            # Also note: This may return as a dask.array.Array object
            return ds_new.where(
                ds_new[lon_var].compute() >= minlon, drop=True
            ).where(
                ds_new[lon_var].compute() <= maxlon, drop=True
            ).where(
                ds_new[lat_var].compute() >= minlat, drop=True
            ).where(
                ds_new[lat_var].compute() <= maxlat, drop=True
            )

        # Use global data to determine cbar bounds if plotting
        # cubed-sphere; otherwise restrict to the plotted region
        if refgridtype == "cs":
            vmin_ref = float(np.nanmin(ds_ref.data))
            vmax_ref = float(np.nanmax(ds_ref.data))
        else:
            ds_ref_reg = get_extent_for_colors(
                ds_ref, min_max_minlon, min_max_maxlon,
                min_max_minlat, min_max_maxlat
            )
            vmin_ref = float(np.nanmin(ds_ref_reg.data))
            vmax_ref = float(np.nanmax(ds_ref_reg.data))
        if devgridtype == "cs":
            vmin_dev = float(np.nanmin(ds_dev.data))
            vmax_dev = float(np.nanmax(ds_dev.data))
        else:
            ds_dev_reg = get_extent_for_colors(
                ds_dev, min_max_minlon, min_max_maxlon,
                min_max_minlat, min_max_maxlat
            )
            vmin_dev = float(np.nanmin(ds_dev_reg.data))
            vmax_dev = float(np.nanmax(ds_dev_reg.data))

        # Set vmin_both and vmax_both to use if match_cbar=True
        vmin_both = np.nanmin([vmin_ref, vmin_dev])
        vmax_both = np.nanmax([vmax_ref, vmax_dev])

        # ==============================================================
        # Test if Ref and/or Dev contain all zeroes or all NaNs.
        # This will have implications as to how we set min and max
        # values for the color ranges below.
        # ==============================================================
        ref_is_all_zero, ref_is_all_nan = all_zero_or_nan(ds_ref.values)
        dev_is_all_zero, dev_is_all_nan = all_zero_or_nan(ds_dev.values)

        # ==============================================================
        # Calculate absolute difference
        # ==============================================================
        if cmpgridtype == "ll":
            absdiff = np.array(ds_dev_cmp) - np.array(ds_ref_cmp)
        else:
            absdiff = ds_dev_cmp_reshaped - ds_ref_cmp_reshaped

        # Test if the abs. diff. is zero everywhere or NaN everywhere
        absdiff_is_all_zero, absdiff_is_all_nan = all_zero_or_nan(absdiff)

        # For cubed-sphere, take special care to avoid a spurious
        # boundary line, as described here: https://stackoverflow.com/
        # questions/46527456/preventing-spurious-horizontal-lines-for-
        # ungridded-pcolormesh-data
        if cmpgridtype == "cs":
            absdiff = np.ma.masked_where(
                np.abs(cmpgrid["lon"] - 180) < 2, absdiff
            )

        # ==============================================================
        # Calculate fractional difference, set divides by zero to NaN
        # ==============================================================
        if cmpgridtype == "ll":
            # Replace fractional difference plots with absolute difference
            # of fractional datasets if necessary
            if frac_ds_dev_cmp is not None and frac_ds_ref_cmp is not None:
                fracdiff = np.array(frac_ds_dev_cmp) - \
                    np.array(frac_ds_ref_cmp)
            else:
                fracdiff = np.abs(np.array(ds_dev_cmp)) / \
                    np.abs(np.array(ds_ref_cmp))
        else:
            if frac_ds_dev_cmp is not None and frac_ds_ref_cmp is not None:
                fracdiff = frac_ds_dev_cmp_reshaped - \
                    frac_ds_ref_cmp_reshaped
            else:
                fracdiff = np.abs(ds_dev_cmp_reshaped) / \
                    np.abs(ds_ref_cmp_reshaped)

        # Replace Infinity and absurdly large values with NaN.
        # (The mask was previously np.abs(fracdiff > 1e308), which missed
        # large negative values.)
        fracdiff = np.where(np.abs(fracdiff) == np.inf, np.nan, fracdiff)
        fracdiff[np.abs(fracdiff) > 1e308] = np.nan

        # Test if the frac. diff. is zero everywhere or NaN everywhere
        fracdiff_is_all_zero = not np.any(fracdiff) or \
            (np.nanmin(fracdiff) == 0 and np.nanmax(fracdiff) == 0)
        fracdiff_is_all_nan = np.isnan(fracdiff).all() or ref_is_all_zero

        # For cubed-sphere, take special care to avoid a spurious
        # boundary line (see link above)
        if cmpgridtype == "cs":
            fracdiff = np.ma.masked_where(
                np.abs(cmpgrid["lon"] - 180) < 2, fracdiff
            )

        # ==============================================================
        # Create 3x2 figure
        # ==============================================================

        # Create figures and axes objects
        # Also define the map projection that will be shown
        if extent[0] > extent[1]:
            proj = ccrs.PlateCarree(central_longitude=180)
        else:
            proj = ccrs.PlateCarree()
        figs, ((ax0, ax1), (ax2, ax3), (ax4, ax5)) = plt.subplots(
            3, 2, figsize=[12, 14],
            subplot_kw={"projection": proj}
        )

        # Ensure subplots don't overlap when invoking plt.show()
        if not savepdf:
            plt.subplots_adjust(hspace=0.4)

        # Give the figure a title
        offset = 0.96
        if "lev" in ds_ref.dims and "lev" in ds_dev.dims:
            if ilev == 0:
                levstr = "Surface"
            elif ilev == 22:
                levstr = "500 hPa"
            else:
                levstr = "Level " + str(ilev - 1)
            if extra_title_txt is not None:
                figs.suptitle(
                    f"{varname}, {levstr} ({extra_title_txt})",
                    y=offset,
                )
            else:
                figs.suptitle(
                    f"{varname}, {levstr}",
                    y=offset
                )
        elif (
            "lat" in ds_ref.dims
            and "lat" in ds_dev.dims
            and "lon" in ds_ref.dims
            and "lon" in ds_dev.dims
        ):
            if extra_title_txt is not None:
                figs.suptitle(
                    f"{varname} ({extra_title_txt})",
                    y=offset,
                )
            else:
                figs.suptitle(
                    f"{varname}",
                    y=offset
                )
        else:
            print(f"Incorrect dimensions for {varname}!")

        # ==============================================================
        # Set colormaps for data plots
        #
        # Use shallow copy (copy.copy() to create color map objects,
        # in order to avoid set_bad() from being applied to the base
        # color table. See: https://docs.python.org/3/library/copy.html
        # ==============================================================

        # Colormaps for 1st row (Ref and Dev)
        if use_cmap_RdBu:
            cmap_toprow_nongray = copy.copy(mpl.colormaps["RdBu_r"])
            cmap_toprow_gray = copy.copy(mpl.colormaps["RdBu_r"])
        else:
            cmap_toprow_nongray = copy.copy(WhGrYlRd)
            cmap_toprow_gray = copy.copy(WhGrYlRd)
        cmap_toprow_gray.set_bad(color="gray")

        if refgridtype == "ll":
            if ref_is_all_nan:
                ref_cmap = cmap_toprow_gray
            else:
                ref_cmap = cmap_toprow_nongray

            if dev_is_all_nan:
                dev_cmap = cmap_toprow_gray
            else:
                dev_cmap = cmap_toprow_nongray

        # Colormaps for 2nd row (Abs. Diff.) and 3rd row (Frac. Diff,)
        cmap_nongray = copy.copy(mpl.colormaps["RdBu_r"])
        cmap_gray = copy.copy(mpl.colormaps["RdBu_r"])
        cmap_gray.set_bad(color="gray")

        # ==============================================================
        # Set titles for plots
        # ==============================================================
        if refgridtype == "ll":
            ref_title = f"{refstr} (Ref){subtitle_extra}\n{refres}"
        else:
            ref_title = f"{refstr} (Ref){subtitle_extra}\nc{refres}"

        if devgridtype == "ll":
            dev_title = f"{devstr} (Dev){subtitle_extra}\n{devres}"
        else:
            dev_title = f"{devstr} (Dev){subtitle_extra}\nc{devres}"

        if regridany:
            absdiff_dynam_title = \
                f"Difference ({cmpres})\nDev - Ref, Dynamic Range"
            absdiff_fixed_title = \
                f"Difference ({cmpres})\nDev - Ref, Restricted Range [5%,95%]"
            if diff_of_diffs:
                fracdiff_dynam_title = \
                    f"Difference ({cmpres}), " + \
                    f"Dynamic Range\n{frac_devstr} - {frac_refstr}"
                fracdiff_fixed_title = \
                    f"Difference ({cmpres}), " + \
                    f"Restricted Range [5%,95%]\n{frac_devstr} - {frac_refstr}"
            else:
                fracdiff_dynam_title = \
                    f"Ratio ({cmpres})\nDev/Ref, Dynamic Range"
                fracdiff_fixed_title = \
                    f"Ratio ({cmpres})\nDev/Ref, Fixed Range"
        else:
            absdiff_dynam_title = "Difference\nDev - Ref, Dynamic Range"
            absdiff_fixed_title = \
                "Difference\nDev - Ref, Restricted Range [5%,95%]"
            if diff_of_diffs:
                fracdiff_dynam_title = \
                    f"Difference, Dynamic Range\n{frac_devstr} - {frac_refstr}"
                fracdiff_fixed_title = \
                    "Difference, Restricted Range " + \
                    f"[5%,95%]\n{frac_devstr} - {frac_refstr}"
            else:
                fracdiff_dynam_title = "Ratio \nDev/Ref, Dynamic Range"
                fracdiff_fixed_title = "Ratio \nDev/Ref, Fixed Range"

        # ==============================================================
        # Bundle variables for 6 parallel plotting calls
        # 0 = Ref                 1 = Dev
        # 2 = Dynamic abs diff    3 = Restricted abs diff
        # 4 = Dynamic frac diff   5 = Restricted frac diff
        # ==============================================================
        subplots = six_panel_subplot_names(diff_of_diffs)

        all_zeros = [
            ref_is_all_zero,
            dev_is_all_zero,
            absdiff_is_all_zero,
            absdiff_is_all_zero,
            fracdiff_is_all_zero,
            fracdiff_is_all_zero,
        ]

        all_nans = [
            ref_is_all_nan,
            dev_is_all_nan,
            absdiff_is_all_nan,
            absdiff_is_all_nan,
            fracdiff_is_all_nan,
            fracdiff_is_all_nan,
        ]

        if -1000 not in extent:
            extents = [extent[:] for _ in range(6)]
        else:
            plot_extent = [np.max([cmp_extent[0], -180]),
                           np.min([cmp_extent[1], 180]),
                           cmp_extent[2], cmp_extent[3]]
            extents = [plot_extent[:] for _ in range(6)]

        plot_vals = [ds_ref, ds_dev, absdiff, absdiff, fracdiff, fracdiff]
        grids = [refgrid, devgrid,
                 regional_cmp_grid.copy(), regional_cmp_grid.copy(),
                 regional_cmp_grid.copy(), regional_cmp_grid.copy()]
        axs = [ax0, ax1, ax2, ax3, ax4, ax5]
        rowcols = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
        titles = [
            ref_title,
            dev_title,
            absdiff_dynam_title,
            absdiff_fixed_title,
            fracdiff_dynam_title,
            fracdiff_fixed_title,
        ]

        if refgridtype == "ll":
            cmaps = [ref_cmap, dev_cmap, cmap_gray, cmap_gray,
                     cmap_gray, cmap_gray]
        else:
            cmaps = [
                cmap_toprow_nongray,
                cmap_toprow_nongray,
                cmap_nongray,
                cmap_nongray,
                cmap_nongray,
                cmap_nongray,
            ]

        ref_masked = None
        dev_masked = None
        if refgridtype == "cs":
            ref_masked = np.ma.masked_where(
                np.abs(refgrid["lon"] - 180) < 2, ds_ref_reshaped
            )
        if devgridtype == "cs":
            dev_masked = np.ma.masked_where(
                np.abs(devgrid["lon"] - 180) < 2, ds_dev_reshaped
            )
        masked = [ref_masked, dev_masked, absdiff,
                  absdiff, fracdiff, fracdiff]

        gridtypes = [
            refgridtype,
            devgridtype,
            cmpgridtype,
            cmpgridtype,
            cmpgridtype,
            cmpgridtype,
        ]

        unit_list = [ds_ref.units, ds_dev.units, cmn_units, cmn_units,
                     "unitless", "unitless"]

        other_all_nans = [dev_is_all_nan, ref_is_all_nan,
                          False, False, False, False]

        mins = [vmin_ref, vmin_dev, vmin_both]
        maxs = [vmax_ref, vmax_dev, vmax_both]

        ratio_logs = [False, False, False, False, True, True]

        # Plot
        for i in range(6):
            six_plot(
                subplots[i],
                all_zeros[i],
                all_nans[i],
                plot_vals[i],
                grids[i],
                axs[i],
                rowcols[i],
                titles[i],
                cmaps[i],
                unit_list[i],
                extents[i],
                masked[i],
                other_all_nans[i],
                gridtypes[i],
                mins,
                maxs,
                use_cmap_RdBu,
                match_cbar,
                verbose,
                log_color_scale,
                plot_type="single_level",
                ratio_log=ratio_logs[i],
                proj=proj,
                ll_plot_func=ll_plot_func,
                **extra_plot_args
            )

        # ==============================================================
        # Add this page of 6-panel plots to a PDF file
        # ==============================================================
        if savepdf:
            folders = pdfname.split('/')
            pdfname_temp = folders[-1] + "BENCHMARKFIGCREATION.pdf" + str(ivar)
            full_path = temp_dir
            for folder in folders[:-1]:
                full_path = os.path.join(full_path, folder)
                if not os.path.isdir(full_path):
                    try:
                        os.mkdir(full_path)
                    except FileExistsError:
                        pass
            pdf = PdfPages(os.path.join(full_path, pdfname_temp))
            pdf.savefig(figs)
            pdf.close()
            plt.close(figs)

        # ==============================================================
        # Report whether this variable has significant differences.
        # Criterion: abs(1 - max(fracdiff)) > 0.1
        # Do not include NaNs in the criterion, because these indicate
        # places where fracdiff could not be computed (div-by-zero).
        # The caller appends the returned name to sigdiff_list (this
        # works both in-process and for joblib worker processes).
        # ==============================================================
        if np.abs(1 - np.nanmax(fracdiff)) > 0.1:
            return varname
        return ""

    # ==================================================================
    # Call figure generation function in a parallel loop over variables
    # ==================================================================
    # do not attempt nested thread parallelization due to issues with
    # matplotlib
    if current_process().name != "MainProcess":
        n_job = 1

    if not savepdf:
        # disable parallel plotting to allow interactive figure plotting
        results = [createfig(i) for i in range(n_var)]

    else:
        with TemporaryDirectory() as temp_dir:

            # ---------------------------------------
            # Turn off parallelization if n_job=1
            if n_job != 1:
                results = Parallel(n_jobs=n_job)(
                    delayed(createfig)(i, temp_dir)
                    for i in range(n_var)
                )
            else:
                results = [createfig(i, temp_dir) for i in range(n_var)]

            # ==========================================================
            # Finish
            # ==========================================================

            # Close the PDF object
            pdf = PdfPages(pdfname)
            pdf.close()
            if verbose:
                print("Closed PDF")

            # Concatenate individual PDFs together
            # Now use PdfWriter instead of PdfMerger
            writer = PdfWriter()
            for i in range(n_var):
                temp_pdfname = pdfname
                if pdfname.startswith('/'):
                    temp_pdfname = temp_pdfname[1:]
                temp_pdfname = os.path.join(
                    str(temp_dir),
                    f"{temp_pdfname}BENCHMARKFIGCREATION.pdf{str(i)}"
                )
                reader = PdfReader(temp_pdfname)
                for page in reader.pages:
                    writer.add_page(page)

            # Write combined PDF
            with open(pdfname, "wb") as ofile:
                writer.write(ofile)
            if verbose:
                print(f"Created {pdfname} for {n_var} variables")

    # Update the significant-difference list exactly once per variable.
    # createfig returns "" for variables without significant differences,
    # so those are skipped (they were previously appended as "").
    sigdiff_list.extend(name for name in results if name)

    warnings.showwarning = _warning_format