Package mvpa :: Package algorithms :: Module cvtranserror
[hide private]
[frames | no frames]

Source Code for Module mvpa.algorithms.cvtranserror

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Cross-validate a classifier on a dataset""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13  from mvpa.support.copy import deepcopy 
 14   
 15  from mvpa.measures.base import DatasetMeasure 
 16  from mvpa.datasets.splitters import NoneSplitter 
 17  from mvpa.base import warning 
 18  from mvpa.misc.state import StateVariable, Harvestable 
 19  from mvpa.misc.transformers import GrandMean 
 20   
 21  if __debug__: 
 22      from mvpa.base import debug 
 23   
 24   
class CrossValidatedTransferError(DatasetMeasure, Harvestable):
    """Classifier cross-validation.

    This class provides a simple interface to cross-validate a classifier
    on datasets generated by a splitter from a single source dataset.

    Arbitrary performance/error values can be computed by specifying an error
    function (used to compute an error value for each cross-validation fold)
    and a combiner function that aggregates all computed error values across
    cross-validation folds.
    """

    results = StateVariable(enabled=False, doc=
       """Store individual results in the state""")
    splits = StateVariable(enabled=False, doc=
       """Store the actual splits of the data. Can be memory expensive""")
    transerrors = StateVariable(enabled=False, doc=
       """Store copies of transerrors at each step""")
    confusion = StateVariable(enabled=False, doc=
       """Store total confusion matrix (if available)""")
    training_confusion = StateVariable(enabled=False, doc=
       """Store total training confusion matrix (if available)""")
    samples_error = StateVariable(enabled=False,
                        doc="Per sample errors.")


    def __init__(self,
                 transerror,
                 splitter=None,
                 combiner='mean',
                 expose_testdataset=False,
                 harvest_attribs=None,
                 copy_attribs='copy',
                 **kwargs):
        """
        :Parameters:
          transerror: TransferError instance
            Provides the classifier used for cross-validation.
          splitter: Splitter | None
            Used to split the dataset for cross-validation folds. By
            convention the first dataset in the tuple returned by the
            splitter is used to train the provided classifier. If the
            first element is 'None' no training is performed. The second
            dataset is used to generate predictions with the (trained)
            classifier. If `None` (default) an instance of
            :class:`~mvpa.datasets.splitters.NoneSplitter` is used.
          combiner: Functor | 'mean'
            Used to aggregate the error values of all cross-validation
            folds. If 'mean' (default) the grand mean of the transfer
            errors is computed.
          expose_testdataset: bool
            In a proper pipeline the classifier must not know anything
            about the testing data, but in some cases exposing it causes
            only marginal harm and thus might be wanted (e.g. to provide
            the testdataset to RFE to determine its stopping point).
          harvest_attribs: list of basestring
            What attributes of call to store and return within
            harvested state variable
          copy_attribs: None | basestring
            Force copying values of attributes on harvesting
          **kwargs:
            All additional arguments are passed to the
            :class:`~mvpa.measures.base.DatasetMeasure` base class.
        """
        DatasetMeasure.__init__(self, **kwargs)
        Harvestable.__init__(self, harvest_attribs, copy_attribs)

        if splitter is None:
            self.__splitter = NoneSplitter()
        else:
            self.__splitter = splitter

        if combiner == 'mean':
            self.__combiner = GrandMean
        else:
            self.__combiner = combiner

        self.__transerror = transerror
        self.__expose_testdataset = expose_testdataset

# TODO: put back in ASAP
#    def __repr__(self):
#        """String summary over the object
#        """
#        return """CrossValidatedTransferError /
# splitter: %s
# classifier: %s
# errorfx: %s
# combiner: %s""" % (indentDoc(self.__splitter), indentDoc(self.__clf),
#                      indentDoc(self.__errorfx), indentDoc(self.__combiner))


    def _call(self, dataset):
        """Perform cross-validation on a dataset.

        'dataset' is passed to the splitter instance and serves as the source
        dataset to generate split for the single cross-validation folds.

        :Parameters:
          dataset
            Source dataset; must provide `origids` and is handed to the
            configured splitter.

        :Returns:
          Whatever the configured combiner returns when called with the
          list of per-fold results.
        """
        # NOTE: the names of the local variables in this method are handed to
        # 'self._harvest' via locals(), hence they are part of what a user
        # can request through 'harvest_attribs' -- do not rename them lightly.

        # store the results of the splitprocessor
        results = []
        self.splits = []

        # local bindings
        states = self.states
        clf = self.__transerror.clf
        expose_testdataset = self.__expose_testdataset

        # what states to enable in terr
        terr_enable = [state_var
                       for state_var in ['confusion',
                                         'training_confusion',
                                         'samples_error']
                       if states.isEnabled(state_var)]

        # charge states with initial values
        summaryClass = clf._summaryClass
        clf_hastestdataset = hasattr(clf, 'testdataset')

        self.confusion = summaryClass()
        self.training_confusion = summaryClass()
        self.transerrors = []
        self.samples_error = dict([(origid, [])
                                   for origid in dataset.origids])

        # enable requested states in child TransferError instance (restored
        # again in the 'finally' clause below)
        if terr_enable:
            self.__transerror.states._changeTemporarily(
                enable_states=terr_enable)

        # Make sure 'transerror' is bound even if the splitter does not
        # generate a single split -- it is re-bound to self below.
        transerror = self.__transerror

        try:
            # splitter
            for split in self.__splitter(dataset):
                # only train classifier if splitter provides something in
                # first element of tuple -- this is the behavior of
                # TransferError
                if states.isEnabled("splits"):
                    self.splits.append(split)

                if states.isEnabled("transerrors"):
                    # copy first and then train, as some classifiers cannot
                    # be copied when already trained, e.g. SWIG'ed stuff
                    transerror = deepcopy(self.__transerror)
                else:
                    transerror = self.__transerror

                # assign testing dataset if given classifier can digest it.
                # It has to go to the classifier of the TransferError which
                # is actually run: when operating on a copy, that is not
                # the classifier bound to 'clf' above.
                if clf_hastestdataset and expose_testdataset:
                    transerror.clf.testdataset = split[1]

                try:
                    # run the beast
                    result = transerror(split[1], split[0])
                finally:
                    # unbind the testdataset from the classifier, even if
                    # the fold has failed
                    if clf_hastestdataset and expose_testdataset:
                        transerror.clf.testdataset = None

                # next line is important for 'self._harvest' call
                self._harvest(locals())

                # XXX Look below -- may be we should have not auto added .?
                #     then transerrors also could be deprecated
                if states.isEnabled("transerrors"):
                    self.transerrors.append(transerror)

                # XXX: could be merged with next for loop using a utility
                # class that can add dict elements into a list
                if states.isEnabled("samples_error"):
                    for k, v in \
                      transerror.states.getvalue("samples_error").items():
                        self.samples_error[k].append(v)

                # pull in child states
                for state_var in ['confusion', 'training_confusion']:
                    if states.isEnabled(state_var):
                        states.getvalue(state_var).__iadd__(
                            transerror.states.getvalue(state_var))

                if __debug__:
                    debug("CROSSC", "Split #%d: result %s" \
                          % (len(results), repr(result)))
                results.append(result)

            # Since we could have operated with a copy -- bind the last used
            # one back
            self.__transerror = transerror
        finally:
            # put states of child TransferError back into original config;
            # done in 'finally' so a failing fold does not leave the
            # temporarily enabled states behind
            if terr_enable:
                self.__transerror.states._resetEnabledTemporarily()

        # Store state variable if it is enabled
        self.results = results

        # Provide those labels_map if appropriate. This is deliberately
        # best-effort (e.g. dataset might have no labels_map), but must not
        # swallow KeyboardInterrupt/SystemExit as a bare 'except' would.
        try:
            if states.isEnabled("confusion"):
                states.confusion.labels_map = dataset.labels_map
            if states.isEnabled("training_confusion"):
                states.training_confusion.labels_map = dataset.labels_map
        except Exception:
            pass

        return self.__combiner(results)


    splitter = property(fget=lambda self:self.__splitter,
                        doc="Access to the Splitter instance.")
    transerror = property(fget=lambda self:self.__transerror,
                        doc="Access to the TransferError instance.")
    combiner = property(fget=lambda self:self.__combiner,
                        doc="Access to the configured combiner.")
233