1
2
3
4
5
6
7
8
"""Basic (f)MRI plotting with the ability to perform thresholding interactively.

"""
12
13 import pylab as P
14 import numpy as N
15 import matplotlib as mpl
16
17 from mvpa.base import warning, externals
18
19 if externals.exists('nifti', raiseException=True):
20 from nifti import NiftiImage
21
# Matplotlib backends for which plotMRI turns interactive mode off while
# drawing (and back on afterwards) and, when `interactive` is None, connects
# the click-to-threshold handler and calls show().
_interactive_backends = ['GTKAgg', 'TkAgg']
23
def _load_volume(arg):
    """Return `arg` as a volume usable by plotMRI.

    File names are loaded as NiftiImage, and such freshly loaded images are
    padded with leading singleton axes up to 3 dimensions.  ndarrays must
    already be 3D.  Anything else (None, a NiftiImage instance) is returned
    unchanged.

    :Raises:
      ValueError
        If `arg` is an ndarray which is not 3D.
    """
    if isinstance(arg, basestring):
        arg = NiftiImage(arg)
        shape = arg.data.shape
        if len(shape) < 3:
            # prepend singleton axes so the volume is (at least) 3D
            arg.data = arg.data.reshape((1,) * (3 - len(shape)) + shape)
    if isinstance(arg, N.ndarray) and len(arg.shape) != 3:
        raise ValueError("For now just handling 3D volumes")
    return arg


def _symneg_std(values):
    """Estimate the std of a zero-centred normal from the negative tail.

    Non-positive values are mirrored around 0 (negative values are counted
    on both sides, zeros once) and the root-mean-square of that symmetric
    sample is returned.

    :Parameters:
      values : array-like
        Overlay values (already masked).

    :Returns:
      (std, n) -- estimated standard deviation and the size of the
      symmetrized sample.

    :Raises:
      ValueError
        If `values` has no negative entries, so no spread can be estimated
        (the estimate would be 0 or NaN).
    """
    values = N.asarray(values)
    negative = values[values < 0]
    if not len(negative):
        raise ValueError("vlim_type='symneg_z' requires some negative overlay "
                         "values to estimate the null distribution")
    nonpositive = values[values <= 0]
    symmetric = N.hstack((-negative, nonpositive))
    return N.sqrt(N.mean(symmetric ** 2)), len(symmetric)


def _format_stats(values, values_thr, thresh_str):
    """Return the text summary shown in plotMRI's info panel.

    :Parameters:
      values : array-like
        Overlay values inside the overlay mask.
      values_thr : array-like
        Subset of `values` which passed the threshold; must be non-empty.
      thresh_str : basestring
        Human-readable description of the threshold.
    """
    # 'm*' keys hold means, 'mm*' keys hold medians -- keep labels matched
    stats = {'v': len(values),
             'vt': len(values_thr),
             'm': N.mean(values),
             'mt': N.mean(values_thr),
             'min': N.min(values),
             'mint': N.min(values_thr),
             'max': N.max(values),
             'maxt': N.max(values_thr),
             'mm': N.median(values),
             'mmt': N.median(values_thr),
             'std': N.std(values),
             'stdt': N.std(values_thr),
             'sthr': thresh_str}
    return """
 Original:
  voxels = %(v)d
  range = [%(min).3g, %(max).3g]
  mean = %(m).3g
  median = %(mm).3g
  std = %(std).3g

 Thresholded: %(sthr)s:
  voxels = %(vt)d
  range = [%(mint).3g, %(maxt).3g]
  mean = %(mt).3g
  median = %(mmt).3g
  std = %(stdt).3g
""" % stats


def plotMRI(background=None, background_mask=None, cmap_bg='gray',
            overlay=None, overlay_mask=None, cmap_overlay='autumn',
            vlim=(0.0, None), vlim_type=None,
            do_stretch_colors=False,
            add_info=True, add_hist=True, add_colorbar=True,
            fig=None, interactive=None
            ):
    """Very basic plotting of 3D data with interactive thresholding.

    Background/overlay could be NIfTI file names, NiftiImage objects, or 3D
    ndarrays.  If no mask is provided, only non-zero elements are plotted.
    Slices are taken along the first array axis and laid out in a roughly
    square grid.  Data of NiftiImage inputs are flipped before plotting
    (the last two axes for the background, the second-to-last axis for the
    overlay); plain ndarrays are used as given.

    :Parameters:
      background : basestring or NiftiImage or ndarray
        Volume drawn underneath the overlay.  Currently required.
      background_mask : basestring or NiftiImage or ndarray or None
        Non-zero voxels select the background voxels to draw.  Defaults to
        the non-zero voxels of `background`.
      cmap_bg : basestring
        Name of a colormap in `pylab.cm` used for the background.
      overlay : basestring or NiftiImage or ndarray
        Volume which gets thresholded and drawn on top.  Currently required.
      overlay_mask : basestring or NiftiImage or ndarray or None
        Non-zero voxels select the overlay voxels considered.  Defaults to
        the non-zero voxels of `overlay`.
      cmap_overlay : basestring
        Name of a colormap in `pylab.cm` used for the overlay.  Its reversed
        ('_r') variant is used when only an upper threshold is active.
      vlim : scalar or sequence of 2
        Lower/upper bounds of overlay values to plot; None leaves that side
        unbounded.  A scalar is taken as the lower bound.  If the upper bound
        is below the lower one, the selection is inverted (not yet fully
        supported).
      vlim_type : None or 'symneg_z'
        If not None, then vlim would be treated accordingly:
          symneg_z
            z-score values of symmetric normal around 0, estimated by
            symmetrizing negative part of the distribution, which often
            could be assumed when total distribution is a mixture of
            by-chance performance normal around 0, and some other in the
            positive tail.  The fitted normal is also drawn on the histogram.
      do_stretch_colors : bool
        Stretch color range to the whole overlay data range (not just to
        the visible, thresholded data).
      add_info : bool
        Add a text panel with statistics of masked and thresholded overlay
        values.
      add_hist : bool
        Add a histogram of masked overlay values.  In interactive mode a
        click on it sets the lower (left button) or upper (right button)
        bound, or swaps them (middle button), and redraws.
      add_colorbar : bool
        Add a colorbar for the overlay.
      fig : Figure or None
        Figure to draw into (it gets cleared); a new one is created if None.
      interactive : bool or None
        Connect the click handler and call `pylab.show()`.  None means: only
        if the current backend is listed in `_interactive_backends`.

    :Returns:
      The matplotlib Figure which was drawn.

    :Raises:
      ValueError
        If background or overlay is missing, the overlay mask selects no
        voxels, `vlim` is malformed, or `vlim_type` is unknown/inapplicable.

    Available colormaps are presented nicely on
      http://www.scipy.org/Cookbook/Matplotlib/Show_colormaps

    TODO:
      * Make interface more attractive/usable
      * allow multiple overlays... or just unify for them all to be just a
        list of entries
      * handle cases properly when there is only one - background/overlay
    """
    # --- background -------------------------------------------------------
    bg = _load_volume(background)
    if bg is None:
        raise ValueError("plotMRI currently requires a background volume")
    if isinstance(bg, NiftiImage):
        # extent (voxel size * count) of header axes 3, 2, 1 -- presumably
        # (z, y, x); TODO confirm against the NIfTI axis conventions
        fov = (N.array(bg.header['pixdim']) * bg.header['dim'])[3:0:-1]
        aspect = fov[2] / fov[1]
        # NOTE(review): both in-plane axes are flipped here but only the
        # second-to-last one for the overlay below -- custom, kept as-is
        bg = bg.data[..., ::-1, ::-1]
    else:
        aspect = 1.0

    bg_mask = _load_volume(background_mask)
    if isinstance(bg_mask, NiftiImage):
        bg_mask = bg_mask.data[..., ::-1, ::-1]
    if bg_mask is not None:
        bg_mask = bg_mask != 0
    else:
        bg_mask = bg != 0

    # --- overlay ----------------------------------------------------------
    func = _load_volume(overlay)
    if func is None:
        raise ValueError("plotMRI currently requires an overlay volume")
    if isinstance(func, NiftiImage):
        func = func.data[..., ::-1, :]

    func_mask = _load_volume(overlay_mask)
    if isinstance(func_mask, NiftiImage):
        func_mask = func_mask.data[..., ::-1, :]
    if func_mask is not None:
        func_mask = func_mask != 0
    else:
        func_mask = func != 0
    if not N.any(func_mask):
        raise ValueError("Overlay mask selects no voxels to plot")

    # --- thresholds -------------------------------------------------------
    if N.isscalar(vlim):
        vlim = (vlim, None)
    vlim = list(vlim)
    if len(vlim) != 2:
        raise ValueError("vlim must be a scalar or a sequence of 2 values, "
                         "got %r" % (vlim,))

    # (callable, plot kwargs) pairs drawn on top of the histogram
    add_dist2hist = []
    if vlim_type is not None:
        if vlim_type != 'symneg_z':
            raise ValueError("Unknown specification of vlim_type=%r. "
                             "Known is: 'symneg_z'" % (vlim_type,))
        std, nfsym = _symneg_std(func[func_mask])
        # vlim was given in z units -- convert to data units
        vlim = [v if v is None else std * v for v in vlim]

        def null_counts(x):
            # N(0, std^2) density scaled to nfsym samples (per unit of x)
            return nfsym / (N.sqrt(2 * N.pi) * std) \
                   * N.exp(-(x ** 2) / (2 * std ** 2))
        add_dist2hist = [(null_counts, {})]

    class Plotter(object):
        """Draw the slice mosaic and redraw it when thresholds change.

        Volumes, masks and display options are taken from the enclosing
        plotMRI call; the figure, the threshold list and the histogram axes
        live on the instance so that `on_click` can update them.
        """

        def __init__(self, fig, vlim):
            self.fig = fig
            # [lower, upper] in data units; on_click edits it in place
            self.vlim = vlim
            # histogram axes, set by do_plot when add_hist is on
            self.hist_sp = None

        def do_plot(self):
            """(Re)draw slices plus optional colorbar, histogram and info."""
            vlim = list(self.vlim)
            # open-ended bounds default to the full overlay range
            if vlim[0] is None:
                vlim[0] = N.min(func)
            if vlim[1] is None:
                vlim[1] = N.max(func)
            invert = vlim[1] < vlim[0]
            if invert:
                vlim = [vlim[1], vlim[0]]
                warning("Inverted vlim (upper < lower) is not yet fully "
                        "supported")

            func_masked = func[func_mask]
            # a lower bound below the smallest masked value is meaningless
            masked_min = N.min(func_masked)
            if vlim[0] < masked_min:
                vlim[0] = masked_min

            bound_above = bool(max(vlim) < N.max(func))
            bound_below = bool(min(vlim) > N.min(func))

            # only an upper bound -> use the reversed overlay colormap
            cmap_name = cmap_overlay
            if bound_above and not bound_below:
                if cmap_name.endswith('_r'):
                    cmap_name = cmap_name[:-2]
                else:
                    cmap_name += '_r'
            # look colormaps up by name rather than eval()-ing user input
            func_cmap = getattr(P.cm, cmap_name)
            bg_cmap = getattr(P.cm, cmap_bg)

            if do_stretch_colors:
                clim = (N.min(func), N.max(func))
            else:
                clim = tuple(vlim)

            extend, thresh_str = {
                (True, True): ('both', 'x in [%.3g, %.3g]' % tuple(vlim)),
                (True, False): ('min', 'x in [%.3g, +inf]' % vlim[0]),
                (False, True): ('max', 'x in (-inf, %.3g]' % vlim[1]),
                (False, False): ('neither', 'none'),
            }[(bound_below, bound_above)]

            # layout: a roughly square grid of slices plus one extra column
            # for histogram/info
            nslices = func.shape[0]
            ndcolumns = ncolumns = int(N.sqrt(nslices))
            nrows = int(N.ceil(nslices * 1.0 / ncolumns))
            ncolumns += int(add_info or add_hist)
            if add_info and add_hist and nrows < 2:
                nrows = 2

            interactive_backend = mpl.get_backend() in _interactive_backends
            if interactive_backend:
                P.ioff()

            if self.fig is None:
                self.fig = P.figure(facecolor='white',
                                    figsize=(4 * ncolumns, 4 * nrows))
            else:
                self.fig.clf()
            fig = self.fig

            def thresholder(x):
                # True where x is selected (XOR flips it for inverted vlim)
                return N.logical_and(x >= vlim[0], x <= vlim[1]) ^ invert

            im0 = None
            for si in reversed(range(nslices)):
                ax = fig.add_subplot(
                    nrows, ncolumns,
                    (si // ndcolumns) * ncolumns + si % ndcolumns + 1,
                    frame_on=False)
                ax.axison = False

                slice_bg = bg[si]
                slice_bg_ = N.ma.masked_array(
                    slice_bg, mask=N.logical_not(bg_mask[si]))

                slice_sl = func[si]
                out_thresh = N.logical_not(thresholder(slice_sl))
                slice_sl_ = N.ma.masked_array(
                    slice_sl,
                    mask=N.logical_or(out_thresh,
                                      N.logical_not(func_mask[si])))

                kwargs = dict(aspect=aspect, origin='lower')
                # NOTE(review): extent pairs shape[0] with x and shape[1]
                # with y -- kept as in the original, confirm for
                # non-square slices
                extent = (0, slice_bg.shape[0], 0, slice_bg.shape[1])

                # uniform backdrop at the top of the bg colormap (clim 0..1)
                im = ax.imshow(N.ones(slice_sl_.shape), cmap=bg_cmap,
                               extent=extent, **kwargs)
                im.set_clim((0, 1))

                ax.imshow(slice_bg_, interpolation='bilinear',
                          cmap=bg_cmap, **kwargs)

                im = ax.imshow(slice_sl_, interpolation='nearest',
                               cmap=func_cmap, alpha=0.8,
                               extent=extent, **kwargs)
                im.set_clim(*clim)

                if si == 0:
                    im0 = im

            if add_colorbar:
                # attach to this figure, not whichever one is current
                cb = fig.colorbar(im0, shrink=0.8, pad=0.0, drawedges=False,
                                  extend=extend, cmap=func_cmap)
                cb.set_clim(*clim)

            if add_hist:
                hist_sp = self.hist_sp = fig.add_subplot(
                    nrows, ncolumns, ncolumns, frame_on=True)
                minv, maxv = N.min(func_masked), N.max(func_masked)
                if minv < 0 < maxv:
                    # symmetric range so 0 sits in the middle
                    maxx = max(-minv, maxv)
                    range_ = (-maxx, maxx)
                else:
                    range_ = (minv, maxv)
                edges = N.histogram(func_masked, range=range_, bins=31)[1]
                _, bins, patches = hist_sp.hist(func_masked, bins=edges,
                                                align='center', facecolor='r')
                dbin = edges[1] - edges[0]
                for dist, plot_kwargs in add_dist2hist:
                    hist_sp.plot(bins, [dist(x) * dbin for x in bins],
                                 **plot_kwargs)
                if add_colorbar:
                    # colour selected bars like the overlay; reset the rest
                    bar_colors = cb.to_rgba(bins)
                    for patch, color, value in zip(patches, bar_colors, bins):
                        if thresholder(value):
                            patch.set_facecolor(color)
                        else:
                            patch.set_facecolor(None)

            func_thr = func[N.logical_and(func_mask, thresholder(func))]
            if add_info and len(func_thr):
                ax = fig.add_subplot(nrows, ncolumns,
                                     (1 + int(add_hist)) * ncolumns,
                                     frame_on=False)
                ax.axison = False
                ax.text(0, 0.5, _format_stats(func_masked, func_thr,
                                              thresh_str),
                        horizontalalignment='left',
                        verticalalignment='center',
                        transform=ax.transAxes,
                        fontsize=14)

            fig.subplots_adjust(left=0.01, right=0.95, bottom=0.01,
                                hspace=0.01)
            if ndcolumns < 2:
                fig.subplots_adjust(wspace=0.4)
            else:
                fig.subplots_adjust(wspace=0.1)

            if interactive_backend:
                fig.canvas.draw()
                P.ion()

        def on_click(self, event):
            """Adjust thresholds from a click on the histogram and redraw.

            Left button (1) sets the lower bound, right button (3) the upper
            bound and middle button (2) swaps them.  Clicks outside of the
            histogram (or when there is none) are ignored.
            """
            if self.hist_sp is None or event.inaxes is not self.hist_sp:
                return
            if event.button == 1:
                self.vlim[0] = event.xdata
            elif event.button == 3:
                self.vlim[1] = event.xdata
            elif event.button == 2:
                self.vlim[0], self.vlim[1] = self.vlim[1], self.vlim[0]
            self.do_plot()

    plotter = Plotter(fig, vlim)
    plotter.do_plot()

    if interactive is None:
        interactive = mpl.get_backend() in _interactive_backends

    if interactive:
        # connect to this plot's own canvas rather than the current figure
        plotter.fig.canvas.mpl_connect('button_press_event', plotter.on_click)
        P.show()

    return plotter.fig
417