Skip to content

Commit

Permalink
Merge pull request #456 from xpsi-group/425-residual-cluster-plots
Browse files Browse the repository at this point in the history
Clusters plots with the residuals
  • Loading branch information
sguillot authored Dec 6, 2024
2 parents ebfeb3c 46072ee commit fac5459
Showing 1 changed file with 230 additions and 9 deletions.
239 changes: 230 additions & 9 deletions xpsi/PostProcessing/_residual.py
Original file line number Diff line number Diff line change
@@ -1,11 1,12 @@
from pylab import *
from ._global_imports import *

from scipy.ndimage import measurements
from ._signalplot import SignalPlot

class ResidualPlot(SignalPlot):
""" Plot the count data, the posterior-expected count signal, and residuals.
""" Plot the count data, the posterior-expected count signal, residuals, clusters and their related distributions.
The figure contains three panels which share phase as an x-axis:
The figure contains three upper panels which share phase as an x-axis:
* the top panel displays the data count numbers in joint channel-phase
intervals, identically split over two rotational phase cycles;
Expand All @@ -15,6 16,15 @@ class ResidualPlot(SignalPlot):
data and posterior-expected count signal over joint channel-phase
intervals.
If requested via ``plot_clusters``, the figure also contains two lower panels:
* the first panel displays two plots: the one on the left shows the
standardised residuals whose absolute values exceed a chosen
threshold, and the one on the right presents the estimated cluster sizes
derived from these residuals;
* the second panel displays two plots: the one on the left is the residual distribution
compared with a reference Gaussian distribution, and the one on the right shows the distribution of
cluster sizes for residuals whose absolute values exceed the chosen threshold.
The following example is (improved) from `Riley et al. 2019 <https://ui.adsabs.harvard.edu/abs/2019ApJ...887L..21R/abstract>`_:
.. image:: _static/_residualplot.png
Expand Down Expand Up @@ -54,24 64,106 @@ def __init__(self,
data_cmap='inferno',
model_cmap='inferno',
residual_cmap='PuOr',
plot_clusters=False,
threshlim=2.0,
clusters_cmap='PuOr',
clustdist_cmap='PuOr',
mu=0.0,
sigma=1.0,
nbins=50,
**kwargs):
"""
Constructor method for plotting residuals.
:param str data_cmap:
Colormap name from :mod:`matplotlib` to use for the data count numbers
over joint channel-phase intervals.
:param str model_cmap:
Colormap name from :mod:`matplotlib` to use for the posterior-expected
count numbers over joint channel-phase intervals.
:param str residual_cmap:
Colormap name from :mod:`matplotlib` to use for the residuals between
the data and posterior-expected count numbers over joint
channel-phase intervals. A diverging colormap is recommended.
:param bool plot_clusters:
Plot cluster sizes and residual distribution?
:param float threshlim:
Threshold above which residuals are classified as clusters.
:param str clusters_cmap:
Colormap name from :mod:`matplotlib` to use for cluster sizes.
:param str clustdist_cmap:
Colormap name from :mod:`matplotlib` to use for cluster sizes distribution.
:param float mu:
Mean for the optimal gaussian distribution to compare with the residuals.
:param float sigma:
Standard deviation for the optimal gaussian distribution to compare with the residuals.
:param kwargs:
Keyword arguments for :class:`SignalPlot`.
"""
super(ResidualPlot, self).__init__(**kwargs)

# Setup the class
self._data_cmap = data_cmap
self._model_cmap = model_cmap
self._residual_cmap = residual_cmap
self._plot_clusters = plot_clusters

# Do you want to plot clusters ?
if self._plot_clusters:

# Add more columns/rows
cls = type(self)
cls.__rows__ = 5
cls.__ax_rows__ = 5
cls.__ax_columns__ = 5
cls.__height_ratios__ = [1] * 5
cls.__width_ratios__ = [50, 1, 50, 1, 1] # second column for colorbars
cls.__wspace__ = 0.1
cls.__hspace__ = 0.35

# Add parameters to plots
self._threshlim = threshlim
self._mu=mu
self._sigma=sigma
self._nbins=nbins
self._clusters_cmap = clusters_cmap
self._clustdist_cmap = clustdist_cmap

else:

# Restore sizes
cls = type(self)
cls.__rows__ = 3
cls.__columns__ = 1
cls.__ax_rows__ = 3
cls.__ax_columns__ = 2
cls.__height_ratios__ = [1]*3
cls.__width_ratios__ = [50, 1] # second column for colorbars
cls.__wspace__ = 0.025
cls.__hspace__ = 0.125

# Generate the axes for plotting
self._get_figure()

self._ax_data = self._add_subplot(0,0)
self._ax_data_cb = self._add_subplot(0,1)
self._ax_data = self._fig.add_subplot(self._gs[0,:-1])
self._ax_data_cb = self._fig.add_subplot(self._gs[0,-1])

self._ax_model = self._add_subplot(1,0)
self._ax_model_cb = self._add_subplot(1,1)
self._ax_model = self._fig.add_subplot(self._gs[1,:-1])
self._ax_model_cb = self._fig.add_subplot(self._gs[1,-1])

self._ax_resid = self._add_subplot(2,0)
self._ax_resid_cb = self._add_subplot(2,1)
self._ax_resid = self._fig.add_subplot(self._gs[2,:-1])
self._ax_resid_cb = self._fig.add_subplot(self._gs[2,-1])

# Prettify everything
self._ax_resid.set_xlabel(r'$\phi$ [cycles]')
for ax in (self._ax_data, self._ax_model):
ax.tick_params(axis='x', labelbottom=False)
Expand All @@ -82,6 174,43 @@ def __init__(self,
ax.xaxis.set_minor_locator(MultipleLocator(0.05))
ax.set_xlim([0.0,2.0])

# Handle cluster plots if required
if self._plot_clusters:

# Generate axes
self._ax_clres = self._fig.add_subplot(self._gs[3,:2])
self._ax_clust = self._fig.add_subplot(self._gs[3,2:-1], sharex = self._ax_clres)
self._ax_clust_cb = self._fig.add_subplot(self._gs[3,-1])

self._ax_rdist = self._fig.add_subplot(self._gs[4, :2])
self._ax_cdist = self._fig.add_subplot(self._gs[4, 2:-1])

# Prettify
for ax in (self._ax_data, self._ax_model, self._ax_clres, self._ax_clust):
ax.set_xlabel('$\phi$ [cycles]')

for ax in (self._ax_data, self._ax_model):
ax.tick_params(axis='x', labelbottom=True)

self._ax_clres.set_ylabel('channel')
self._ax_clres.xaxis.set_major_locator(MultipleLocator(0.2))
self._ax_clres.xaxis.set_minor_locator(MultipleLocator(0.05))
self._ax_clres.set_xlim([0.0,1.0])
self._ax_clres.set_title(r'|residuals| > {}'.format(self._threshlim))

self._ax_clust.set_title('Cluster sizes')
self._ax_clust.set_yticklabels([])
self._ax_clust.set_yticks([])

self._ax_rdist.set_title('Residual distribution')
self._ax_rdist.xaxis.set_major_locator(MultipleLocator(1.0))
self._ax_rdist.xaxis.set_minor_locator(MultipleLocator(0.5))
self._ax_rdist.set_xlabel('Residuals')

self._ax_cdist.set_title('Cluster sizes distribution')
self._ax_cdist.set_xlabel('Cluster sizes')
self._ax_cdist.tick_params(axis='y', labelright=True, labelleft=False)

if "yscale" in kwargs:
self.yscale = kwargs.get("yscale")
else:
Expand Down Expand Up @@ -134,6 263,9 @@ def finalize(self):
self._add_data()
self._add_expected_counts()
self._add_residuals()
if self._plot_clusters:
self._reveal_clusters()
self._disp_distributions()

def _set_vminmax(self):
""" Compute minimum and maximum for data and model colorbars. """
Expand Down Expand Up @@ -290,3 422,92 @@ def _add_residuals(self):
self._resid_cb.set_label(label=r'$(c_{ik}-d_{ik})/\sqrt{c_{ik}}$',
labelpad=15)

def _reveal_clusters(self):
    """ Display clusters from residuals in the fourth panel.

    The standardised residuals ``(expected - data) / sqrt(expected)`` are
    thresholded at ``self._threshlim`` (absolute value, inclusive), and
    the above-threshold joint channel-phase intervals are grouped into
    clusters with :func:`scipy.ndimage.label` using its default
    structuring element.

    Two meshes are drawn over the phase/channel grid:

    * on ``self._ax_clres``, the residuals that pass the threshold (all
      other intervals are set to zero);
    * on ``self._ax_clust``, for every interval, the size of the cluster
      it belongs to (zero outside clusters), with a colorbar on
      ``self._ax_clust_cb``.

    The intermediate arrays are stored on the instance as
    ``_residuals``, ``_clusteresid``, ``_lw``, ``_num``, ``_clustarea``
    and ``_affectedarea``.

    """
    # Imported locally so this method does not depend on the deprecated
    # ``scipy.ndimage.measurements`` namespace.
    from scipy import ndimage

    expected = self.expected_counts
    self._residuals = (expected - self._signal.data.counts) / _np.sqrt(expected)

    # Boolean mask of the intervals whose |residual| reaches the threshold.
    self._clusteresid = _np.abs(self._residuals) >= self._threshlim

    # Label connected groups of flagged intervals; label 0 is the background.
    self._lw, self._num = ndimage.label(self._clusteresid)

    # Number of flagged intervals per label (entry 0, the background, is 0).
    self._clustarea = ndimage.sum(self._clusteresid,
                                  self._lw,
                                  index=_np.arange(self._lw.max() + 1))

    # Map every interval to the size of the cluster it belongs to.
    self._affectedarea = self._clustarea[self._lw]

    vmaxresid = _np.max(_np.abs(self._residuals))
    vmaxarea = _np.max(self._affectedarea)

    # Calculate channel edges by averaging neighbouring channel values;
    # the two outer edges are extrapolated by half the adjacent spacing.
    channels = _np.asarray(self._signal.data.channels, dtype=float)
    channel_edges = _np.zeros(len(channels) + 1)
    channel_edges[1:-1] = (channels[:-1] + channels[1:]) / 2.0
    channel_edges[0] = channels[0] - (channels[1] - channels[0]) / 2.0
    channel_edges[-1] = channels[-1] + (channels[-1] - channels[-2]) / 2.0

    # The colormap is passed straight to pcolormesh (which accepts a name
    # or a Colormap instance) instead of going through ``cm.get_cmap``,
    # which is no longer available in recent matplotlib releases.
    clust1 = self._ax_clres.pcolormesh(self._signal.data.phases,
                                       channel_edges,
                                       _np.where(self._clusteresid,
                                                 self._residuals,
                                                 0),
                                       cmap = self._residual_cmap,
                                       vmin = -vmaxresid,
                                       vmax = vmaxresid,
                                       linewidth = 0,
                                       rasterized = self._rasterized)
    clust1.set_edgecolor('face')

    clust2 = self._ax_clust.pcolormesh(self._signal.data.phases,
                                       channel_edges,
                                       self._affectedarea,
                                       cmap = self._clusters_cmap,
                                       vmin = 0,
                                       vmax = vmaxarea,
                                       linewidth = 0,
                                       rasterized = self._rasterized)
    clust2.set_edgecolor('face')

    # Keep the lower limit strictly positive so a logarithmic y-scale works.
    ylim = [_np.max([channel_edges[0], 0.001]), channel_edges[-1]]
    self._ax_clres.set_ylim(ylim)
    self._ax_clres.set_yscale(self.yscale)
    # NOTE(review): only the left axes get ``self.yscale``; the right axes
    # keep their default scale -- confirm this is intended.
    self._ax_clust.set_ylim(ylim)

    self._clust_cb = plt.colorbar(clust2, cax = self._ax_clust_cb,
                                  ticks=AutoLocator())
    self._clust_cb.ax.set_frame_on(True)
    self._clust_cb.ax.yaxis.set_minor_locator(AutoMinorLocator())

    self._clust_cb.set_label(label=r'cluster sizes for |residuals| > {}'.format(self._threshlim),
                             labelpad=15)

def _disp_distributions(self):
    """ Display residual and cluster distributions in the fifth panel.

    On ``self._ax_rdist`` the histogram of the standardised residuals
    ``(expected - data) / sqrt(expected)`` is drawn over ``self._nbins``
    bins spanning a range symmetric about zero, together with a Gaussian
    of mean ``self._mu`` and standard deviation ``self._sigma`` scaled to
    the expected number of counts per histogram bin.

    On ``self._ax_cdist`` the number of clusters of each size is drawn,
    where clusters are groups of intervals with
    ``|residual| >= self._threshlim`` as labelled by
    :func:`scipy.ndimage.label`.

    The intermediate arrays are stored on the instance as
    ``_residuals``, ``_clusteresid``, ``_lw``, ``_num``, ``_clustarea``
    and ``_affectedarea``.

    """
    # Imported locally so this method does not depend on the deprecated
    # ``scipy.ndimage.measurements`` namespace.
    from scipy import ndimage

    expected = self.expected_counts
    self._residuals = (expected - self._signal.data.counts) / _np.sqrt(expected)

    # Threshold and label the residuals; label 0 is the background.
    self._clusteresid = _np.abs(self._residuals) >= self._threshlim
    self._lw, self._num = ndimage.label(self._clusteresid)
    self._clustarea = ndimage.sum(self._clusteresid,
                                  self._lw,
                                  index=_np.arange(self._lw.max() + 1))
    self._affectedarea = self._clustarea[self._lw]

    # Histogram range symmetric about zero, wide enough for the largest
    # absolute residual.
    bound = _np.max(_np.abs(self._residuals))
    residhist, binhist = _np.histogram(self._residuals,
                                       bins=self._nbins,
                                       range=[-bound, bound])
    centres = (binhist[:-1] + binhist[1:]) / 2
    binsize = binhist[1] - binhist[0]

    # A unit-area Gaussian times (number of residuals) x (bin width) gives
    # the expected number of residuals per bin, directly comparable with
    # the histogram whatever the shape of the data.
    scale = binsize * self._residuals.size
    f = (1 / (self._sigma * _np.sqrt(2 * _np.pi))
         * _np.exp(-(centres - self._mu)**2 / (2 * self._sigma**2)))

    # Number of labels per cluster size.
    # NOTE(review): the background label (size 0) is included here, so the
    # size-0 entry is always at least 1 -- confirm whether it should be
    # excluded from the distribution.
    sizes = self._clustarea.flatten().astype(int)
    count_arr = _np.bincount(sizes)

    self._ax_rdist.step(centres, residhist)
    self._ax_rdist.plot(centres, f * scale, linewidth=2, color='m')

    # count_arr[k] is the number of labels of size k, so the abscissa is
    # simply 0, 1, ..., largest cluster size.
    self._ax_cdist.step(_np.arange(len(count_arr)),
                        count_arr,
                        where='mid',
                        )

0 comments on commit fac5459

Please sign in to comment.