Source code for gcpy.plot.single_panel

"""
Creates a single panel plot (geographic map or zonal mean).
"""
import copy
from matplotlib import ticker
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
from dask.array import Array as DaskArray
import xarray as xr
import cartopy.crs as ccrs
from gcpy.grid import get_vert_grid, get_pressure_indices, \
    pad_pressure_edges, convert_lev_to_pres, get_grid_extents, \
    call_make_grid, get_input_res
from gcpy.regrid import regrid_comparison_data, create_regridders
from gcpy.util import reshape_MAPL_CS, all_zero_or_nan, verify_variable_type
from gcpy.plot.core  import gcpy_style, normalize_colors, WhGrYlRd

# Suppress numpy divide by zero warnings to prevent output spam
np.seterr(divide="ignore", invalid="ignore")

# Use a style sheet to control plot attributes
plt.style.use(gcpy_style)


[docs] def single_panel( plot_vals, ax=None, plot_type="single_level", grid=None, gridtype="", title="fill", comap=WhGrYlRd, norm=None, unit="", extent=None, masked_data=None, use_cmap_RdBu=False, log_color_scale=False, add_cb=True, pres_range=None, pedge=np.full((1, 1), -1), pedge_ind=np.full((1, 1), -1), log_yaxis=False, xtick_positions=None, xticklabels=None, proj=ccrs.PlateCarree(), sg_path='', ll_plot_func="imshow", vert_params=None, pdfname="", weightsdir='.', vmin=None, vmax=None, return_list_of_plots=False, **extra_plot_args ): """ Core plotting routine -- creates a single plot panel. Parameters ---------- plot_vals : xarray.DataArray or numpy.ndarray or dask.array.Array Single data variable GEOS-Chem output to plot. ax : matplotlib.axes.Axes, optional Axes object to plot information. Default value: None (will create a new axes) plot_type : str, optional Either "single_level" or "zonal_mean". Default value: "single_level" grid : dict, optional Dictionary mapping plot_vals to plottable coordinates. Default value: {} (will attempt to read grid from plot_vals) gridtype : str, optional "ll" for lat/lon or "cs" for cubed-sphere. Default value: "" (will automatically determine from grid) title : str, optional Title to put at top of plot. Default value: "fill" (will use name attribute of plot_vals if available) comap : matplotlib.colors.Colormap, optional Colormap for plotting data values. Default value: WhGrYlRd norm : list, optional List with range [0..1] normalizing color range for matplotlib methods. Default value: None (will determine from plot_vals) unit : str, optional Units of plotted data. Default value: "" (will use units attribute of plot_vals if available) extent : tuple of float, optional Describes minimum and maximum latitude and longitude of input data in the form (minlon, maxlon, minlat, maxlat). Default value: None (will use full extent of plot_vals if plot is single level) masked_data : numpy.ndarray, optional Masked area for avoiding near-dateline cubed-sphere plotting issues. Default value: None (will attempt to determine from plot_vals) use_cmap_RdBu : bool, optional Set this flag to True to use a blue-white-red colormap. Default value: False log_color_scale : bool, optional Set this flag to True to use a log-scale colormap. Default value: False add_cb : bool, optional Set this flag to True to add a colorbar to the plot. Default value: True pres_range : list of int, optional Range from minimum to maximum pressure for zonal mean plotting. Default value: [0, 2000] (will plot entire atmosphere) pedge : numpy.ndarray, optional Edge pressures of vertical grid cells in plot_vals for zonal mean plotting. Default value: np.full((1, 1), -1) (will determine automatically) pedge_ind : numpy.ndarray, optional Index of edge pressure values within pressure range in plot_vals for zonal mean plotting. Default value: np.full((1, 1), -1) (will determine automatically) log_yaxis : bool, optional Set this flag to True to enable log scaling of pressure in zonal mean plots. Default value: False xtick_positions : list of float, optional Locations of lat/lon or lon ticks on plot. Default value: None (will place automatically for zonal mean plots) xticklabels : list of str, optional Labels for lat/lon ticks. Default value: None (will determine automatically from xtick_positions) proj : cartopy.crs.Projection, optional Projection for plotting data. Default value: ccrs.PlateCarree() sg_path : str, optional Path to NetCDF file containing stretched-grid info (in attributes) for plot_vals. Default value: '' (will not be read in) ll_plot_func : str, optional Function to use for lat/lon single level plotting with possible values 'imshow' and 'pcolormesh'. imshow is much faster but is slightly displaced when plotting from dateline to dateline and/or pole to pole. Default value: 'imshow' vert_params : list of array-like, optional Hybrid grid parameter A in hPa and B (unitless). Needed if grid is not 47 or 72 levels. Default value: None pdfname : str, optional File path to save plots as PDF. Default value: "" (will not create PDF) weightsdir : str, optional Directory path for storing regridding weights. Default value: "." (will store regridding files in current directory) vmin : float, optional Minimum for colorbars. Default value: None (will use plot value minimum) vmax : float, optional Maximum for colorbars. Default value: None (will use plot value maximum) return_list_of_plots : bool, optional Return plots as a list. This is helpful if you are using a cubed-sphere grid and would like access to all 6 plots. Default value: False **extra_plot_args Any extra keyword arguments are passed to calls to pcolormesh() (CS) or imshow() (Lat/Lon). Returns ------- plot : matplotlib.artist.Artist Plot object created from input. """ verify_variable_type(plot_vals, (xr.DataArray, np.ndarray, DaskArray)) # Create empty lists for keyword arguments if pres_range is None: pres_range = [0, 2000] if vert_params is None: vert_params = [[], []] # Eliminate 1D level or time dimensions plot_vals = plot_vals.squeeze() data_is_xr = isinstance(plot_vals, xr.DataArray) if xtick_positions is None: xtick_positions = [] if plot_type == "zonal_mean": xtick_positions = np.arange(-90, 90, 30) if xticklabels is None: xticklabels = [rf"{x}$\degree$" for x in xtick_positions] if unit == "" and data_is_xr: try: unit = plot_vals.units.strip() except BaseException: pass if title == "fill" and data_is_xr: try: title = plot_vals.name except BaseException: pass # Generate grid if not passed if grid is None: res, gridtype = get_input_res(plot_vals) sg_params = [1, 170, -90] if sg_path != '': sg_attrs = xr.open_dataset(sg_path).attrs sg_params = [ sg_attrs['stretch_factor'], sg_attrs['target_longitude'], sg_attrs['target_latitude']] if plot_type == 'single_level': grid_extent = get_grid_extents(plot_vals) [grid, _] = call_make_grid( res, gridtype, in_extent=grid_extent, sg_params=sg_params ) else: # zonal mean if np.all(pedge_ind == -1) or np.all(pedge == -1): # Get mid-point pressure and edge pressures for this grid pedge, pmid, _ = get_vert_grid(plot_vals, *vert_params) # Get indexes of pressure subrange (full range is default) pedge_ind = get_pressure_indices(pedge, pres_range) # Pad edges if subset does not include surface or TOA so data spans # entire subrange pedge_ind = pad_pressure_edges( pedge_ind, plot_vals.sizes["lev"], len(pmid)) # pmid indexes do not include last pedge index pmid_ind = pedge_ind[:-1] # Convert levels to pressures in ref and dev data plot_vals = convert_lev_to_pres(plot_vals, pmid, pedge) # get proper levels plot_vals = plot_vals.isel(lev=pmid_ind) [input_res, input_gridtype, _, _, _, new_gridtype, regrid, _, _, _, _, grid, regridder, _, regridder_list, _] = create_regridders( plot_vals, plot_vals, weightsdir=weightsdir, cmpres=None, zm=True, sg_ref_params=sg_params ) if gridtype == 'cs': plot_vals = reshape_MAPL_CS(plot_vals) nlev = len(plot_vals['lev']) # Ref plot_vals = regrid_comparison_data( plot_vals, input_res, regrid, regridder, regridder_list, grid, input_gridtype, new_gridtype, nlev=nlev ) # average across longitude bands # assume lon dim is index 2 (no time dim) if a numpy array is passed lon_ind = 2 if isinstance(plot_vals, xr.DataArray): lon_ind = plot_vals.dims.index('lon') # calculate zonal means plot_vals = plot_vals.mean(axis=lon_ind) if gridtype == "": _, gridtype = get_input_res(plot_vals) if extent is None or extent == (None, None, None, None): extent = get_grid_extents(grid) # convert to -180 to 180 grid if needed (necessary if going # cross-dateline later) if extent[0] > 180 or extent[1] > 180: #extent = [((extent[0]+180)%360)-180, ((extent[1]+180)%360)-180, extent[2], extent[3]] extent = [extent[0] - 180, extent[1] - 180, extent[2], extent[3]] #''' #if extent[0] < -180 and 'x' in res: # lon_res = float(res.split('x')[1]) # extent = [180, #if extent[1] > 180 and 'x' in res: # extent[1] = 180 #''' # Account for cross-dateline extent if extent[0] > extent[1]: if gridtype == "ll": # rearrange data with dateline in the middle instead of prime meridian # change extent / grid to where dateline is 0, prime meridian is -180 / 180 # needed for numpy arrays if doing pcolormesh / imshow, and xarray DataArrays # if using imshow proj = ccrs.PlateCarree(central_longitude=180) if ll_plot_func == "imshow" or \ not isinstance(plot_vals, xr.DataArray): i = 0 while grid['lon_b'][i] < 0: i = i+1 plot_vals_holder = copy.deepcopy(plot_vals) if not isinstance(plot_vals, xr.DataArray): plot_vals_holder[:,:-i] = plot_vals[:,i:] plot_vals_holder[:,-i:] = plot_vals[:,:i] else: plot_vals_holder.values[:,:-i] = plot_vals.values[:,i:] plot_vals_holder.values[:,-i:] = plot_vals.values[:,:i] plot_vals = plot_vals_holder extent[0] = extent[0] % 360 - 180 extent[1] = extent[1] % 360 - 180 grid["lon_b"] = grid["lon_b"] % 360 - 180 grid["lon"] = grid["lon"] % 360 - 180 if isinstance(plot_vals, xr.DataArray): plot_vals['lon'] = plot_vals['lon'] % 360 - 180 # realign grid also if doing imshow or using numpy arrays if ll_plot_func == "imshow" or \ not isinstance(plot_vals, xr.DataArray): temp_grid = copy.deepcopy(grid) temp_grid['lon_b'][:-i] = grid['lon_b'][i:] temp_grid['lon_b'][-i:] = grid['lon_b'][:i] temp_grid['lon'][:-i] = grid['lon'][i:] temp_grid['lon'][-i:] = grid['lon'][:i] grid = temp_grid if isinstance(plot_vals, xr.DataArray): plot_vals = plot_vals.assign_coords({'lon' : grid['lon']}) if gridtype == "cs": proj = ccrs.PlateCarree(central_longitude=180) extent[0] = extent[0] % 360 - 180 extent[1] = extent[1] % 360 - 180 grid["lon_b"] = grid["lon_b"] % 360 - 180 grid["lon"] = grid["lon"] % 360 - 180 if ax is None: if plot_type == "zonal_mean": ax = plt.axes() if plot_type == "single_level": ax = plt.axes(projection=proj) fig = plt.gcf() data_is_xr = isinstance(plot_vals, xr.DataArray) # Normalize colors (put into range [0..1] for matplotlib methods) if norm is None: if data_is_xr: vmin = plot_vals.data.min() if vmin is None else vmin vmax = plot_vals.data.max() if vmax is None else vmax elif isinstance(plot_vals, np.ndarray): vmin = np.min(plot_vals) if vmin is None else vmin vmax = np.max(plot_vals) if vmax is None else vmax norm = normalize_colors( vmin, vmax, is_difference=use_cmap_RdBu, log_color_scale=log_color_scale) # Create plot ax.set_title(title) if plot_type == "zonal_mean": # Zonal mean plot plot = ax.pcolormesh( grid["lat_b"], pedge[pedge_ind], plot_vals, cmap=comap, norm=norm, **extra_plot_args) ax.set_aspect("auto") ax.set_ylabel("Pressure (hPa)") if log_yaxis: ax.set_yscale("log") ax.yaxis.set_major_formatter( ticker.FuncFormatter(lambda y, _: f"{y:g}") ) ax.invert_yaxis() ax.set_xticks(xtick_positions) ax.set_xticklabels(xticklabels) elif gridtype == "ll": if ll_plot_func == 'imshow': # Lat/Lon single level [minlon, maxlon, minlat, maxlat] = extent # expand extent to minimize imshow distortion #[dlat,dlon] = list(map(float, res.split('x'))) dlon = grid['lon'][2] - grid['lon'][1] dlat = grid['lat'][2] - grid['lat'][1] def get_nearest_extent(val, array, direction, spacing): # choose nearest values in grid to desired extent to minimize distortion grid_vals = np.asarray(array) diff = grid_vals - val if direction == 'greater': diff[diff < 0] = np.inf i = diff.argmin() if diff[i] == np.inf: # expand extent to value beyond grid limits if extent # is already > max grid value return grid_vals[(np.abs(grid_vals - val)).argmin()] return grid_vals[i] # if direction is not "greater": diff[diff > 0] = -np.inf i = diff.argmax() if diff[i] == -np.inf: # expand extent to value beyond grid limits if # extent is already < min grid value # plot will be distorted if full global to avoid # cartopy issues return grid_vals[( np.abs(grid_vals - val)).argmin()] - spacing return max(grid_vals[i], -180) closest_minlon = get_nearest_extent( minlon, grid['lon_b'], 'less', dlon) closest_maxlon = get_nearest_extent( maxlon, grid['lon_b'], 'greater', dlon) # don't adjust if extent includes poles where points are not evenly # spaced anyway if np.abs( grid['lat_b'][0] - grid['lat_b'][1]) != np.abs( grid['lat_b'][1] - grid['lat_b'][2]) and minlat < grid['lat_b'][1]: closest_minlat = grid['lat_b'][0] else: closest_minlat = get_nearest_extent( minlat, grid['lat_b'], 'less', dlat) if np.abs(grid['lat_b'][-1] - grid['lat_b'][-2]) != \ np.abs(grid['lat_b'][-2] - grid['lat_b'][-3]) and \ maxlat > grid['lat_b'][-2]: closest_maxlat = grid['lat_b'][-1] else: closest_maxlat = get_nearest_extent( maxlat, grid['lat_b'], 'greater', dlat) extent = [ closest_minlon, closest_maxlon, closest_minlat, closest_maxlat] if isinstance(plot_vals, xr.DataArray): # filter data by bounds of extent plot_vals = plot_vals.where( plot_vals.lon > closest_minlon, drop=True).where( plot_vals.lon < closest_maxlon, drop=True).where( plot_vals.lat > minlat, drop=True).where( plot_vals.lat < maxlat, drop=True) else: # filter data by indices of grid minlon_i = np.where(grid['lon_b']==closest_minlon)[0] if len(minlon_i) == 0: minlon_i = 0 else: minlon_i = int(minlon_i[0]) maxlon_i = np.where(grid['lon_b']==closest_maxlon)[0] if len(maxlon_i) == 0: maxlon_i = -1 else: maxlon_i = int(maxlon_i[0]) minlat_i = np.where(grid['lat_b']==closest_minlat)[0] if len(minlat_i) == 0: minlat_i = 0 else: minlat_i = int(minlat_i[0]) maxlat_i = np.where(grid['lat_b']==closest_maxlat)[0] if len(maxlat_i) == 0: maxlat_i = -1 else: maxlat_i = int(maxlat_i[0]) plot_vals = plot_vals[minlat_i:maxlat_i+1, minlon_i:maxlon_i+1] # Create a lon/lat plot plot = ax.imshow( plot_vals, extent=extent, transform=proj, cmap=comap, norm=norm, origin='lower', interpolation='nearest', **extra_plot_args ) else: plot = ax.pcolormesh( grid["lon_b"], grid["lat_b"], plot_vals, transform=proj, cmap=comap, norm=norm, **extra_plot_args ) ax.set_extent(extent, crs=proj) ax.coastlines() ax.set_xticks(xtick_positions) ax.set_xticklabels(xticklabels) else: # Cubed-sphere single level try: if masked_data is None: masked_data = np.ma.masked_where( np.abs( grid["lon"] - 180) < 2, plot_vals.data.reshape( 6, res, res)) except ValueError: # Comparison of numpy arrays throws errors pass [minlon, maxlon, minlat, maxlat] = extent # Catch issue with plots extending into both the western and eastern # hemisphere if np.max(grid["lon_b"] > 180): grid["lon_b"] = (((grid["lon_b"] + 180) % 360) - 180) plots = [] for j in range(6): plot = ax.pcolormesh( grid["lon_b"][j, :, :], grid["lat_b"][j, :, :], masked_data[j, :, :], transform=proj, cmap=comap, norm=norm, **extra_plot_args ) plots.append(plot) ax.set_extent(extent, crs=proj) ax.coastlines() ax.set_xticks(xtick_positions) ax.set_xticklabels(xticklabels) if add_cb: cbar = plt.colorbar(plot, ax=ax, orientation="horizontal", pad=0.10) cbar.mappable.set_norm(norm) if data_is_xr: all_zero, all_nan = all_zero_or_nan(plot_vals.values) else: all_zero, all_nan = all_zero_or_nan(plot_vals) if all_zero or all_nan: if use_cmap_RdBu: cbar.set_ticks([0.0]) else: cbar.set_ticks([0.5]) if all_nan: cbar.set_ticklabels(["Undefined throughout domain"]) else: cbar.set_ticklabels(["Zero throughout domain"]) else: if log_color_scale: cbar.formatter = ticker.LogFormatter(base=10) else: if (vmax - vmin) < 0.1 or (vmax - vmin) > 100: cbar.locator = ticker.MaxNLocator(nbins=4) try: cbar.formatter.set_useOffset(False) except BaseException: # not all automatically chosen colorbar formatters properly handle # the above method pass cbar.update_ticks() cbar.set_label(unit) if pdfname != "": pdf = PdfPages(pdfname) pdf.savefig(fig) pdf.close() # in some cases users may wish to get a list of all associated plots # eg. cubedsphere grids have six plots associated with them if return_list_of_plots: return plots if 'plots' in locals() else [plot] return plot