# Source code for gbm.plot.model

# model.py: Plot class for spectral fits and models
#
#     Authors: William Cleveland (USRA),
#              Adam Goldstein (USRA) and
#              Daniel Kocevski (NASA)
#
#     Portions of the code are Copyright 2020 William Cleveland and
#     Adam Goldstein, Universities Space Research Association
#     All rights reserved.
#
#     Written for the Fermi Gamma-ray Burst Monitor (Fermi-GBM)
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation, either version 3 of the License, or
#     (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.
#
#     You should have received a copy of the GNU General Public License
#     along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

from .gbmplot import GbmPlot, Histo, ModelData, Collection, ModelSamples
from .lib import *
import warnings

class ModelFit(GbmPlot):
    """Class for plotting spectral fits.

    Parameters:
        fitter (:class:`~gbm.spectra.fitting.SpectralFitter`, optional):
            The spectral fitter
        canvas (optional): Accepted for interface compatibility; it is not
            used by this initializer.
        view (str, optional): The plot view, one of 'counts', 'photon',
            'energy' or 'nufnu'. Default is 'counts'
        resid (bool, optional): If True, plots the residuals in counts view.
            Default is True.
        interactive (bool, optional): If True, turn on matplotlib interactive
            mode (``plt.ion()``). Default is True.

    Attributes:
        ax (:class:`matplotlib.axes`): The matplotlib axes object for the plot
        canvas (Canvas Backend object): The plotting canvas (inherited from
            :class:`~.gbmplot.GbmPlot`; this initializer does not set one)
        count_data (Collection of :class:`~.gbmplot.ModelData`):
            The count data plot elements
        count_models (Collection of :class:`~.gbmplot.Histo`):
            The count model plot elements
        fig (:class:`matplotlib.figure`): The matplotlib figure object
        model_spectrum (list of :class:`~.gbmplot.ModelSamples`):
            The model spectrum sample elements, or None if no photon/energy/
            nufnu view has been plotted yet
        residuals (Collection of :class:`~gbmplot.ModelData`):
            The fit residual plot elements
        view (str): The current plot view
        xlim (float, float): The plotting range of the x axis.
            This attribute can be set.
        xscale (str): The scale of the x axis, either 'linear' or 'log'.
            This attribute can be set.
        ylim (float, float): The plotting range of the y axis.
            This attribute can be set.
        yscale (str): The scale of the y axis, either 'linear' or 'log'.
            This attribute can be set.
    """
    # Define a list of default plotting colors to cycle through
    colors = ('#7F3C8D,#11A579,#3969AC,#F2B701,#E73F74,#80BA5A,#E68310,'
              '#008695,#CF1C90,#f97b72,#4b4b8f,#A5AA99').split(',')
    # Lower bound applied to the model-view y axis when autoscaling goes below it
    _min_y = 1e-10

    def __init__(self, fitter=None, canvas=None, view='counts', resid=True,
                 interactive=True):
        # Silence NumPy's VisibleDeprecationWarning process-wide (as before).
        # The class lives in ``np.exceptions`` since NumPy 1.25 and is no
        # longer available at the top level in NumPy 2.0.
        vdw = getattr(getattr(np, 'exceptions', np),
                      'VisibleDeprecationWarning', None)
        if vdw is not None:
            warnings.filterwarnings("ignore", category=vdw)

        # Main panel (3/4 height) stacked over a residuals panel (1/4 height)
        self._figure, axes = plt.subplots(2, 1, sharex=True, sharey=False,
                                          figsize=(5.7, 6.7), dpi=100,
                                          gridspec_kw={'height_ratios': [3, 1]})
        plt.subplots_adjust(hspace=0)
        self._ax = axes[0]
        self._resid_ax = axes[1]

        self._view = view
        self._fitter = None
        self._count_models = Collection()
        self._count_data = Collection()
        self._resids = Collection()
        self._model_spectrum = None
        # legacy alias of _model_spectrum, kept for backward compatibility
        self._spectrum_model = None

        # plot data and/or background if set on init
        if fitter is not None:
            self.set_fit(fitter, resid=resid)

        if interactive:
            plt.ion()

    @property
    def view(self):
        """str: The current plot view"""
        return self._view

    @property
    def count_models(self):
        """Collection: The count model plot elements"""
        return self._count_models

    @property
    def count_data(self):
        """Collection: The count data plot elements"""
        return self._count_data

    @property
    def residuals(self):
        """Collection: The fit residual plot elements"""
        return self._resids

    @property
    def model_spectrum(self):
        """list or None: The model spectrum sample elements"""
        return self._model_spectrum

    def set_fit(self, fitter, resid=False):
        """Set the fitter and (re)plot the fit in the current view.

        Args:
            fitter (:class:`~gbm.spectra.fitting.SpectralFitter`):
                The spectral fitter for which a fit has been performed
            resid (bool, optional): If True, plot the fit residuals. Only
                used in the 'counts' view; the model views always hide them.
        """
        self._fitter = fitter

        if self._view == 'counts':
            self.count_spectrum()
            if resid:
                self.show_residuals()
            else:
                # hide silently; the user did not explicitly ask to hide
                self._remove_residual_axes()
        elif self._view == 'photon':
            self.photon_spectrum()
        elif self._view == 'energy':
            self.energy_spectrum()
        elif self._view == 'nufnu':
            self.nufnu_spectrum()
        # any other view string plots nothing (unchanged behavior)

    def count_spectrum(self):
        """Plot the count spectrum fit
        """
        self._view = 'counts'
        self._ax.clear()
        # clear() removed every artist from the axes, so start with fresh
        # collections instead of keeping references to stale artists
        self._count_models = Collection()
        self._count_data = Collection()

        model_counts = self._fitter.model_count_spectrum()
        energy, chanwidths, data_counts, data_counts_err, ulmasks = \
            self._fitter.data_count_spectrum()

        for i in range(self._fitter.num_sets):
            det = self._fitter.detectors[i]
            color = self._color(i)
            self._count_models.insert(det, Histo(model_counts[i], self._ax,
                                                 edges_to_zero=False,
                                                 color=color, alpha=1.0,
                                                 label=det))
            self._count_data.insert(det, ModelData(energy[i], data_counts[i],
                                                   chanwidths[i],
                                                   data_counts_err[i],
                                                   self._ax,
                                                   ulmask=ulmasks[i],
                                                   color=color, alpha=0.7,
                                                   linewidth=0.9))

        self._ax.set_ylabel(r'Rate [count s$^{-1}$ keV$^{-1}$]',
                            fontsize=PLOTFONTSIZE)
        self._set_view()
        self._ax.legend()

    def photon_spectrum(self, **kwargs):
        """Plot the photon spectrum model

        Args:
            num_samples (int, optional): The number of sample spectra.
                                         Default is 100.
        """
        self._view = 'photon'
        self._plot_spectral_model(**kwargs)
        self._ax.set_ylabel(r'Photon Flux [ph cm$^{-2}$ s$^{-1}$ keV$^{-1}$]',
                            fontsize=PLOTFONTSIZE)

    def energy_spectrum(self, **kwargs):
        """Plot the energy spectrum model

        Args:
            num_samples (int, optional): The number of sample spectra.
                                         Default is 100.
        """
        self._view = 'energy'
        self._plot_spectral_model(**kwargs)
        self._ax.set_ylabel(r'Energy Flux [ph cm$^{-2}$ s$^{-1}$]',
                            fontsize=PLOTFONTSIZE)

    def nufnu_spectrum(self, **kwargs):
        """Plot the nuFnu spectrum model

        Args:
            num_samples (int, optional): The number of sample spectra.
                                         Default is 100.
        """
        self._view = 'nufnu'
        self._plot_spectral_model(**kwargs)
        self._ax.set_ylabel(r'$\nu F_\nu$ [keV ph cm$^{-2}$ s$^{-1}$]',
                            fontsize=PLOTFONTSIZE)

    def show_residuals(self, sigma=True):
        """Show the fit residuals

        Args:
            sigma (bool, optional): If True, plot the residuals in units of
                                    model sigma, otherwise in units of counts.
                                    Default is True.
        """
        # re-attach the residuals axes if hide_residuals() removed it
        if self._resid_ax not in self._figure.axes:
            self._figure.add_axes(self._resid_ax)
        # the residuals panel carries the x-axis labels, not the main panel
        self._ax.xaxis.set_tick_params(which='both', labelbottom=False)
        self._ax.set_xlabel('')

        # start from an empty panel so repeated calls do not stack artists
        self._resid_ax.clear()
        self._resids = Collection()

        # get the residuals
        energy, chanwidths, resid, resid_err = \
            self._fitter.residuals(sigma=sigma)

        # plot for each detector/dataset, tracking the extent of the errorbars
        lows, highs = [], []
        for i in range(self._fitter.num_sets):
            det = self._fitter.detectors[i]
            self._resids.insert(det, ModelData(energy[i], resid[i],
                                               chanwidths[i], resid_err[i],
                                               self._resid_ax,
                                               color=self._color(i),
                                               alpha=0.7, linewidth=0.9))
            lows.append(np.nanmin(resid[i] - resid_err[i]))
            highs.append(np.nanmax(resid[i] + resid_err[i]))

        # the zero line
        self._resid_ax.axhline(0.0, color='black')
        self._resid_ax.set_xlabel('Energy [keV]', fontsize=PLOTFONTSIZE)
        if sigma:
            self._resid_ax.set_ylabel('Residuals [sigma]',
                                      fontsize=PLOTFONTSIZE)
        else:
            self._resid_ax.set_ylabel('Residuals [counts]',
                                      fontsize=PLOTFONTSIZE)

        # we have to set the y-axis range manually, because the y-axis
        # autoscale is broken (known issue) in matplotlib for this situation.
        # Skip it when there is nothing finite to bound (no data sets or
        # all-NaN residuals), since set_ylim rejects NaN limits.
        if lows:
            ymin = np.nanmin(lows)
            ymax = np.nanmax(highs)
            if np.isfinite(ymin) and np.isfinite(ymax):
                # pad each limit outward by 10% of its magnitude
                self._resid_ax.set_ylim((1.0 - np.sign(ymin) * 0.1) * ymin,
                                        (1.0 + np.sign(ymax) * 0.1) * ymax)

    def hide_residuals(self):
        """Hide the fit residuals

        Prints a message if the residuals are already hidden.
        """
        if not self._remove_residual_axes():
            print('Residuals already hidden')

    def _remove_residual_axes(self):
        """Detach the residuals axes from the figure, if it is attached.

        Returns:
            bool: True if the axes was removed, False if it was not present.
        """
        if self._resid_ax not in self._figure.axes:
            return False
        self._figure.delaxes(self._resid_ax)
        # with the residuals panel gone, the main panel must show the x axis
        self._ax.xaxis.set_tick_params(which='both', labelbottom=True)
        self._ax.set_xlabel('Energy [keV]', fontsize=PLOTFONTSIZE)
        return True

    def _color(self, index):
        """Return the plot color for ``index``, cycling through ``colors``."""
        return self.colors[index % len(self.colors)]

    def _set_view(self):
        """Set the view properties
        """
        self._ax.set_xlim(self._fitter.energy_range)
        self._ax.yaxis.set_tick_params(labelsize=PLOTFONTSIZE)
        self._ax.set_xscale('log')
        self._ax.set_yscale('log')
        self._ax.set_xlabel('Energy [keV]', fontsize=PLOTFONTSIZE)

    def _plot_spectral_model(self, num_samples=100, plot_components=True):
        """Plot the spectral model by sampling from the Gaussian approximation
        to the parameters' posterior.

        Args:
            num_samples (int, optional): The number of sample spectra.
                                         Default is 100.
            plot_components (bool, optional): If True and the model has more
                than one component, plot each component in addition to the
                total. Default is True.
        """
        # clean plot and hide residuals if any.
        # NOTE: this UserWarning filter is process-wide (unchanged behavior).
        warnings.filterwarnings("ignore", category=UserWarning)
        self._ax.clear()
        self._remove_residual_axes()

        num_comp = self._fitter.num_components
        comps = self._fitter.function_components
        name = self._fitter.function_name

        # if the number of model components is > 1, plot each one
        if (num_comp > 1) and plot_components:
            energies, samples = self._fitter.sample_spectrum(
                which=self._view, num_samples=num_samples, components=True)
            models = [ModelSamples(energies, samples[:, i, :], self._ax,
                                   label=comps[i], color=self._color(i + 1),
                                   alpha=0.1, lw=0.3)
                      for i in range(num_comp)]
            # the total model is the sum over the component axis
            samples = samples.sum(axis=1)
        else:
            # or just plot the function
            models = []
            energies, samples = self._fitter.sample_spectrum(
                which=self._view, num_samples=num_samples)

        y_max = samples.max(axis=(1, 0))
        models.append(ModelSamples(energies, samples, self._ax, label=name,
                                   color=self._color(0), alpha=0.1, lw=0.3))
        # publish through the model_spectrum property (and the legacy alias)
        self._model_spectrum = models
        self._spectrum_model = models

        self._set_view()

        # fix the alphas for the legend; ``legendHandles`` was renamed to
        # ``legend_handles`` in matplotlib 3.7 and removed in 3.9
        legend = self._ax.legend()
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles
        for lh in handles:
            lh.set_alpha(1)
            lh.set_linewidth(1.0)

        if self._ax.get_ylim()[0] < self._min_y:
            self._ax.set_ylim(self._min_y, 10.0 * y_max)