Package mvpa :: Package tests :: Module test_searchlight
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_searchlight

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA searchlight algorithm""" 
 10   
 11  from mvpa.base import externals 
 12  from mvpa.datasets.masked import MaskedDataset 
 13  from mvpa.measures.searchlight import Searchlight 
 14  from mvpa.datasets.splitters import NFoldSplitter 
 15  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 16  from mvpa.clfs.transerror import TransferError 
 17   
 18  from tests_warehouse import * 
 19  from tests_warehouse_clfs import * 
 20   
class SearchlightTests(unittest.TestCase):
    """Unit tests for the PyMVPA Searchlight measure."""

    def setUp(self):
        # Shared fixture from tests_warehouse; the sphere counts asserted
        # below (106 spheres, sizes 4..7 at radius 1) are tied to it.
        self.dataset = datasets['3dlarge']


    def testSearchlight(self):
        """Full searchlight: one cross-validated error per sphere."""
        # compute N-1 cross-validation for each sphere
        transerror = TransferError(sample_clf_lin)
        cv = CrossValidatedTransferError(
                transerror,
                NFoldSplitter(cvtype=1))
        # construct radius 1 searchlight
        sl = Searchlight(cv, radius=1.0, transformer=N.array,
                         enable_states=['spheresizes'])

        # run searchlight
        results = sl(self.dataset)

        # check for correct number of spheres
        self.assertEqual(len(results), 106)

        # check for chance-level performance across all spheres
        self.assertTrue(0.4 < results.mean() < 0.6)

        # check reasonable sphere sizes
        self.assertEqual(len(sl.spheresizes), 106)
        self.assertEqual(max(sl.spheresizes), 7)
        self.assertEqual(min(sl.spheresizes), 4)

        # check base-class state
        self.assertEqual(len(sl.raw_results), 106)


    def testPartialSearchlightWithFullReport(self):
        """Searchlight on two centers only, keeping the error of every fold."""
        # compute N-1 cross-validation for each sphere; combiner=N.array
        # makes each sphere report one error per CV fold (see shape check)
        transerror = TransferError(sample_clf_lin)
        cv = CrossValidatedTransferError(
                transerror,
                NFoldSplitter(cvtype=1),
                combiner=N.array)
        # construct radius 1 searchlight restricted to two sphere centers
        sl = Searchlight(cv, radius=1.0, transformer=N.array,
                         center_ids=[3,50])

        # run searchlight
        results = sl(self.dataset)

        # only two spheres but error for all CV-folds
        self.assertEqual(results.shape, (2, len(self.dataset.uniquechunks)))


    def testChiSquareSearchlight(self):
        """Searchlight driven by a custom callable (chi-square of confusion)."""
        # chisquare (mvpa.misc.stats) is only imported when scipy is
        # available; otherwise the test is silently skipped
        if not externals.exists('scipy'):
            return

        from mvpa.misc.stats import chisquare

        transerror = TransferError(sample_clf_lin)
        cv = CrossValidatedTransferError(
                transerror,
                NFoldSplitter(cvtype=1),
                enable_states=['confusion'])


        def getconfusion(data):
            # run cross-validation on the sphere's data, then reduce its
            # confusion matrix to the first element returned by chisquare
            # (presumably the chi-square statistic -- verify in misc.stats)
            cv(data)
            return chisquare(cv.confusion.matrix)[0]

        # construct radius 1 searchlight; only do partial (two centers)
        # to save time
        sl = Searchlight(getconfusion, radius=1.0,
                         center_ids=[3,50])

        # run searchlight
        results = sl(self.dataset)

        self.assertEqual(len(results), 2)
def suite():
    """Return a TestSuite with every ``test*`` method of SearchlightTests."""
    # unittest.makeSuite is deprecated (3.11) and removed in 3.13; this
    # loader call is what makeSuite did internally with default settings.
    return unittest.TestLoader().loadTestsFromTestCase(SearchlightTests)


if __name__ == '__main__':
    # PyMVPA test 'runner' module is imported only for its side effects
    # (presumably it executes the test suite on import -- see runner.py)
    import runner