Source Code for Module mvpa.clfs.base

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Base class for all classifiers.

At the moment, regressions are treated just as a special case of
classifiers (or vice versa), so the same base class `Classifier` is
utilized for both kinds.
"""

__docformat__ = 'restructuredtext'

import numpy as N

from mvpa.support.copy import deepcopy

import time

from mvpa.misc.support import idhash
from mvpa.misc.state import StateVariable, ClassWithCollections
from mvpa.misc.param import Parameter

from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics

from mvpa.base import warning

if __debug__:
    from mvpa.base import debug

class Classifier(ClassWithCollections):
    """Abstract classifier class to be inherited by all classifiers
    """

    # Kept separate from doc to not pollute help(clf), especially if
    # we include help for the parent class
    _DEV__doc__ = """
    Required behavior:

    For every classifier it has to be possible to be instantiated without
    having to specify the training pattern.

    Repeated calls to the train() method with different training data have to
    result in a valid classifier, trained for the particular dataset.

    It must be possible to specify all classifier parameters as keyword
    arguments to the constructor.

    Recommended behavior:

    Derived classifiers should provide access to *values* -- i.e. the
    information that is finally used to determine the predicted class label.

    Michael: Maybe it works well if each classifier provides a 'values'
    state member. This variable is a list as long as, and in the same order
    as, Dataset.uniquelabels (training data). Each item in the list
    corresponds to the likelihood of a sample belonging to the
    respective class. However, the semantics might differ between
    classifiers, e.g. kNN would probably store distances to class
    neighbors, whereas PLR would store the raw function value of the
    logistic function. So in the case of kNN low is predictive and for
    PLR high is predictive. Don't know if there is a need to unify
    that.

    As the storage and/or computation of this information might be
    demanding, its collection should be switchable and off by default.

    Nomenclature
     * predictions : corresponds to the quantized labels if the classifier
       spits out labels by .predict()
     * values : might be different from predictions if a classifier's
       predict() makes a decision based on some internal value such as
       probability or a distance.
    """
    # Dict that contains the parameters of a classifier.
    # This shall provide an interface to plug generic parameter optimizers
    # on all classifiers (e.g. grid- or line-search optimizers)
    # A dictionary is used because Michael thinks that access by name is nicer.
    # Additionally Michael thinks ATM that additional information might be
    # necessary in some situations (e.g. reasonably predefined parameter range,
    # minimal iteration stepsize, ...), therefore the value to each key should
    # also be a dict or we should use mvpa.misc.param.Parameter...

    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_nsamples = StateVariable(enabled=True,
        doc="Number of samples it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent " +
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to train")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDs which were used for the actual training.")

    _clf_internals = []
    """Describes some specifics about the classifier -- whether it is
    doing regression, for instance..."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Whether to use the classifier as a regression. By default any
        Classifier-derived class serves as a classifier, so a regression
        does binary classification.""", index=1001)

    # TODO: make it available only for actually retrainable classifiers
    retrainable = Parameter(False, allowedtype='bool',
        doc="""Whether to enable retraining for a 'retrainable' classifier.""",
        index=1002)

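    # A hedged usage sketch (added for illustration; `SomeClassifier` stands
    # in for any concrete Classifier subclass, not a real class in PyMVPA):
    # the StateVariables above live in the `states` collection and can be
    # toggled per instance, e.g.
    #
    #   clf = SomeClassifier()
    #   clf.states.enable('training_confusion')  # off by default, see above
    #   clf.train(dataset)
    #   print clf.trained_labels                 # filled in by _posttrain()
    #   print clf.training_confusion.error
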
    def __init__(self, **kwargs):
        """Cheap initialization.
        """
        ClassWithCollections.__init__(self, **kwargs)

        self.__trainednfeatures = None
        """Stores number of features for which classifier was trained.
        If None -- it wasn't trained at all"""

        self._setRetrainable(self.params.retrainable, force=True)

        if self.params.regression:
            for statevar in [ "trained_labels"]: #, "training_confusion" ]:
                if self.states.isEnabled(statevar):
                    if __debug__:
                        debug("CLF",
                              "Disabling state %s since doing regression, " %
                              statevar + "not classification")
                    self.states.disable(statevar)
            self._summaryClass = RegressionStatistics
        else:
            self._summaryClass = ConfusionMatrix
            clf_internals = self._clf_internals
            if 'regression' in clf_internals and not ('binary' in clf_internals):
                # regressions are used as binary classifiers if not
                # asked to perform regression explicitly
                # We need a copy of the list, so we don't override class-wide
                self._clf_internals = clf_internals + ['binary']

        # deprecate
        #self.__trainedidhash = None
        #"""Stores id of the dataset on which it was trained to signal
        #in trained() if it was trained already on the same dataset"""

    def __str__(self):
        if __debug__ and 'CLF_' in debug.active:
            return "%s / %s" % (repr(self), super(Classifier, self).__str__())
        else:
            return repr(self)

    def __repr__(self, prefixes=[]):
        return super(Classifier, self).__repr__(prefixes=prefixes)

    def _pretrain(self, dataset):
        """Functionality prior to training
        """
        # So we reset all state variables and maybe free up some memory
        # explicitly
        params = self.params
        if not params.retrainable:
            self.untrain()
        else:
            # just reset the states, do not untrain
            self.states.reset()
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                __idhashes = self.__idhashes
                __invalidatedChangedData = self.__invalidatedChangedData

                # if we don't know what was changed we need to figure
                # it out
                if __debug__:
                    debug('CLF_', "IDHashes are %s" % (__idhashes))

                # Look at the data if any was changed
                for key, data_ in (('traindata', dataset.samples),
                                   ('labels', dataset.labels)):
                    _changedData[key] = self.__wasDataChanged(key, data_)
                    # if those idhashes were invalidated by retraining
                    # we need to adjust _changedData accordingly
                    if __invalidatedChangedData.get(key, False):
                        if __debug__ and not _changedData[key]:
                            debug('CLF_', 'Found that idhash for %s was '
                                  'invalidated by retraining' % key)
                        _changedData[key] = True

                # Look at the parameters
                for col in self._paramscols:
                    changedParams = self._collections[col].whichSet()
                    if len(changedParams):
                        _changedData[col] = changedParams

            self.__invalidatedChangedData = {} # reset it on training

            if __debug__:
                debug('CLF_', "Obtained _changedData is %s"
                      % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
           and not self.states.isEnabled('trained_labels'):
            # if classifier internally does regression we need to have
            # labels it was trained on
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')

    def _posttrain(self, dataset):
        """Functionality post training

        For instance -- computing confusion matrix
        :Parameters:
          dataset : Dataset
            Data which was used for training
        """
        if self.states.isEnabled('trained_labels'):
            self.trained_labels = dataset.uniquelabels

        self.trained_dataset = dataset
        self.trained_nsamples = dataset.nsamples

        # needs to be assigned first since below we use predict
        self.__trainednfeatures = dataset.nfeatures

        if __debug__ and 'CHECK_TRAINED' in debug.active:
            self.__trainedidhash = dataset.idhash

        if self.states.isEnabled('training_confusion') and \
               not self.states.isSet('training_confusion'):
            # we should not store predictions for training data,
            # it is confusing imho (yoh)
            self.states._changeTemporarily(
                disable_states=["predictions"])
            if self.params.retrainable:
                # we would need to recheck if data is the same,
                # XXX think if there is a way to make this all
                # efficient. For now, probably, retrainable
                # classifiers have no choice but not to use
                # training_confusion... sad
                self.__changedData_isset = False
            predictions = self.predict(dataset.samples)
            self.states._resetEnabledTemporarily()
            self.training_confusion = self._summaryClass(
                targets=dataset.labels,
                predictions=predictions)

            try:
                self.training_confusion.labels_map = dataset.labels_map
            except:
                pass

        if self.states.isEnabled('feature_ids'):
            self.feature_ids = self._getFeatureIds()

    def _getFeatureIds(self):
        """Virtual method to return feature_ids used while training

        Is not intended to be called anywhere but from _posttrain,
        thus classifier is assumed to be trained at this point
        """
        # By default all features are used
        return range(self.__trainednfeatures)

    def summary(self):
        """Provide a summary of the classifier"""

        s = "Classifier %s" % self
        states = self.states
        states_enabled = states.enabled

        if self.trained:
            s += "\n trained"
            if states.isSet('training_time'):
                s += ' in %.3g sec' % states.training_time
            s += ' on data with'
            if states.isSet('trained_labels'):
                s += ' labels:%s' % list(states.trained_labels)

            nsamples, nchunks = None, None
            if states.isSet('trained_nsamples'):
                nsamples = states.trained_nsamples
            if states.isSet('trained_dataset'):
                td = states.trained_dataset
                nsamples, nchunks = td.nsamples, len(td.uniquechunks)
            if nsamples is not None:
                s += ' #samples:%d' % nsamples
            if nchunks is not None:
                s += ' #chunks:%d' % nchunks

            s += " #features:%d" % self.__trainednfeatures
            if states.isSet('feature_ids'):
                s += ", used #features:%d" % len(states.feature_ids)
            if states.isSet('training_confusion'):
                s += ", training error:%.3g" % states.training_confusion.error
        else:
            s += "\n not yet trained"

        if len(states_enabled):
            s += "\n enabled states:%s" % ', '.join([str(states[x])
                                                     for x in states_enabled])
        return s

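    # For illustration (hedged -- the numbers below are made up, but the shape
    # follows the string building above), summary() of a trained classifier
    # renders along the lines of:
    #
    #   Classifier <...>
    #    trained in 0.012 sec on data with labels:[0, 1] #samples:40 #chunks:8 #features:100, training error:0.05
    #    enabled states:predictions, values, trained_labels, ...
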
    def clone(self):
        """Create a full copy of the classifier.

        It might require the classifier to be untrained first due to
        present SWIG bindings.

        TODO: think about proper re-implementation, without enrollment of deepcopy
        """
        try:
            return deepcopy(self)
        except:
            self.untrain()
            return deepcopy(self)

    def _train(self, dataset):
        """Function to be actually overridden in derived classes
        """
        raise NotImplementedError

    def train(self, dataset):
        """Train classifier on a dataset

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so
        """
        if __debug__:
            debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
                  msgargs={'clf':self, 'dataset':dataset})

        self._pretrain(dataset)

        # remember the time when started training
        t0 = time.time()

        if dataset.nfeatures > 0:
            result = self._train(dataset)
        else:
            warning("Trying to train on dataset with no features present")
            if __debug__:
                debug("CLF",
                      "No features present for training, no actual training "
                      "is called")
            result = None

        self.training_time = time.time() - t0
        self._posttrain(dataset)
        return result

    def _prepredict(self, data):
        """Functionality prior to prediction
        """
        if not ('notrain2predict' in self._clf_internals):
            # check if classifier was trained if that is needed
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            # check if number of features is the same as in the data
            # it was trained on
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures

        if self.params.retrainable:
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                    self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))

    def _postpredict(self, data, result):
        """Functionality after prediction is computed
        """
        self.predictions = result
        if self.params.retrainable:
            self.__changedData_isset = False

    def _predict(self, data):
        """Actual prediction
        """
        raise NotImplementedError

    def predict(self, data):
        """Predict classifier on data

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so. Also, subclasses trying to call the superclass's predict
        should call _predict from within _predict instead of predict()
        since otherwise it would loop
        """
        data = N.asarray(data)
        if __debug__:
            debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
                  msgargs={'clf':self, 'data':data.shape})

        # remember the time when started computing predictions
        t0 = time.time()

        states = self.states
        # to assure that those are reset (could be set due to testing
        # post-training)
        states.reset(['values', 'predictions'])

        self._prepredict(data)

        if self.__trainednfeatures > 0 \
               or 'notrain2predict' in self._clf_internals:
            result = self._predict(data)
        else:
            warning("Trying to predict using classifier trained on no features")
            if __debug__:
                debug("CLF",
                      "No features were present for training, prediction is "
                      "bogus")
            result = [None]*data.shape[0]

        states.predicting_time = time.time() - t0

        if 'regression' in self._clf_internals and not self.params.regression:
            # We need to convert regression values into labels
            # XXX maybe unify labels -> internal_labels conversion.
            #if len(self.trained_labels) != 2:
            #    raise RuntimeError, "Ask developer to implement for " \
            #        "multiclass mapping from regression into classification"

            # must be N.array so we copy it to assign labels directly
            # into labels, or should we just recreate "result"???
            result_ = N.array(result)
            if states.isEnabled('values'):
                # values could be set by now so assigning 'result' would
                # be misleading
                if not states.isSet('values'):
                    states.values = result_.copy()
                else:
                    # it might be that the values are pointing to result at
                    # the moment, so lets assure this silly way that
                    # they do not overlap
                    states.values = states.values.copy()

            trained_labels = self.trained_labels
            for i, value in enumerate(result):
                dists = N.abs(value - trained_labels)
                result[i] = trained_labels[N.argmin(dists)]

            if __debug__:
                debug("CLF_", "Converted regression result %(result_)s "
                      "into labels %(result)s for %(self_)s",
                      msgargs={'result_':result_, 'result':result,
                               'self_': self})

        self._postpredict(data, result)
        return result

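    # A small worked example (added for illustration) of the regression ->
    # label conversion above: with trained_labels == N.array([-1, 1]) and a
    # raw regression output value of 0.3, dists == N.abs(0.3 - trained_labels)
    # == [1.3, 0.7], so N.argmin(dists) picks index 1 and the predicted label
    # becomes 1.
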
    # deprecate ???
    def isTrained(self, dataset=None):
        """Whether the classifier was already trained.

        MUST BE USED WITH CARE IF EVER"""
        if dataset is None:
            # simply return if it was trained on anything
            return not self.__trainednfeatures is None
        else:
            res = (self.__trainednfeatures == dataset.nfeatures)
            if __debug__ and 'CHECK_TRAINED' in debug.active:
                res2 = (self.__trainedidhash == dataset.idhash)
                if res2 != res:
                    raise RuntimeError, \
                          "isTrained is weak and shouldn't be relied upon. " \
                          "Got result %s although comparing of idhash says %s" \
                          % (res, res2)
            return res

    def _regressionIsBogus(self):
        """Some classifiers like BinaryClassifier can't be used for
        regression"""

        if self.params.regression:
            raise ValueError, "Regression mode is meaningless for %s" % \
                  self.__class__.__name__ + " thus don't enable it"

    @property
    def trained(self):
        """Whether the classifier was already trained"""
        return self.isTrained()

    def untrain(self):
        """Reset trained state"""
        self.__trainednfeatures = None
        # probably not needed... retrainable shouldn't be fully untrained
        # or should be???
        #if self.params.retrainable:
        #    # ??? don't duplicate the code ;-)
        #    self.__idhashes = {'traindata': None, 'labels': None,
        #                       'testdata': None, 'testtraindata': None}
        super(Classifier, self).reset()

    def getSensitivityAnalyzer(self, **kwargs):
        """Factory method to return an appropriate sensitivity analyzer for
        the respective classifier."""
        raise NotImplementedError


    #
    # Methods which are needed for retrainable classifiers
    #
    def _setRetrainable(self, value, force=False):
        """Assign value of retrainable parameter

        If retrainable flag is to be changed, classifier has to be
        untrained. Also internal attributes such as _changedData,
        __changedData_isset, and __idhashes should be initialized if
        it becomes retrainable
        """
        pretrainable = self.params['retrainable']
        if (force or value != pretrainable.value) \
               and 'retrainable' in self._clf_internals:
            if __debug__:
                debug("CLF_", "Setting retrainable to %s" % value)
            if 'meta' in self._clf_internals:
                warning("Retrainability is not yet crafted/tested for "
                        "meta classifiers. Unpredictable behavior might occur")
            # assure that we don't drag anything behind
            if self.trained:
                self.untrain()
            states = self.states
            if not value and states.isKnown('retrained'):
                states.remove('retrained')
                states.remove('repredicted')
            if value:
                if not 'retrainable' in self._clf_internals:
                    warning("Setting of flag retrainable for %s has no effect"
                            " since classifier has no such capability. It would"
                            " just lead to resources consumption and slowdown"
                            % self)
                states.add(StateVariable(enabled=True,
                        name='retrained',
                        doc="Whether the retrainable classifier was retrained"))
                states.add(StateVariable(enabled=True,
                        name='repredicted',
                        doc="Whether the retrainable classifier was repredicted"))

            pretrainable.value = value

            # if retrainable we need to keep track of things
            if value:
                self.__idhashes = {'traindata': None, 'labels': None,
                                   'testdata': None} #, 'testtraindata': None}
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    # ??? it is not clear though if idhash is faster than
                    # simple comparison of (dataset != __traineddataset).any(),
                    # but if we like to get rid of __traineddataset then we
                    # should use idhash anyways
                    self.__trained = self.__idhashes.copy() # just same Nones
                self.__resetChangedData()
                self.__invalidatedChangedData = {}
            elif 'retrainable' in self._clf_internals:
                #self.__resetChangedData()
                self.__changedData_isset = False
                self._changedData = None
                self.__idhashes = None
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    self.__trained = None

    def __resetChangedData(self):
        """For retrainable classifiers we keep track of what was changed.
        This function resets that dictionary
        """
        if __debug__:
            debug('CLF_',
                  'Retrainable: resetting flags on whether data was changed')
        keys = self.__idhashes.keys() + self._paramscols
        # we might like to just reinit values to False???
        #_changedData = self._changedData
        #if isinstance(_changedData, dict):
        #    for key in _changedData.keys():
        #        _changedData[key] = False
        self._changedData = dict(zip(keys, [False]*len(keys)))
        self.__changedData_isset = False

    def __wasDataChanged(self, key, entry, update=True):
        """Check if given entry was changed from what was known prior.

        If so -- store only the ones needed for retrainable beastie
        """
        idhash_ = idhash(entry)
        __idhashes = self.__idhashes

        changed = __idhashes[key] != idhash_
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            __trained = self.__trained
            changed2 = entry != __trained[key]
            if isinstance(changed2, N.ndarray):
                changed2 = changed2.any()
            if changed != changed2 and not changed:
                raise RuntimeError, \
                  'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
                  'values %s!=%s %s' % \
                  (key, idhash_, __idhashes[key], changed,
                   entry, __trained[key], changed2)
            if update:
                __trained[key] = entry

        if __debug__ and changed:
            debug('CLF_', "Changed %s from %s to %s.%s"
                  % (key, __idhashes[key], idhash_,
                     ('','updated')[int(update)]))
        if update:
            __idhashes[key] = idhash_

        return changed


#    def __updateHashIds(self, key, data):
#        """Is a twofold operation: updates hashid if it was said to change.
#
#        Or, if it wasn't said that data changed but CHECK_RETRAIN is active
#        and the data is found to be changed -- raise Exception
#        """
#
#        check_retrain = __debug__ and 'CHECK_RETRAIN' in debug.active
#        chd = self._changedData
#
#        # we need to update idhashes
#        if chd[key] or check_retrain:
#            keychanged = self.__wasDataChanged(key, data)
#        if check_retrain and keychanged and not chd[key]:
#            raise RuntimeError, \
#                  "Data %s found changed although wasn't " \
#                  "labeled as such" % key


    #
    # Additional API which is specific only for retrainable classifiers.
    # For now it would just puke if asked from a non-retrainable one.
    #
    # Might come useful and efficient for statistics testing, so if just
    # labels of the dataset changed, then
    #   self.retrain(dataset, labels=True)
    # would cause efficient retraining (no kernels recomputed etc)
    # and subsequent self.repredict(data) should be also quite fast ;-)
    #
    def retrain(self, dataset, **kwargs):
        """Helper to avoid checking whether data has actually changed

        Useful if just some aspects of the classifier were changed since
        its previous training. For instance, if the dataset wasn't changed
        but only the classifier parameters, then the kernel matrix does not
        have to be recomputed.

        Words of caution: the classifier must be previously trained, and
        results should always first be compared to the results of a not
        'retrainable' classifier (without calling retrain). Some
        additional checks are enabled if debug id 'CHECK_RETRAIN' is
        enabled, to guard against obvious mistakes.

        :Parameters:
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        # Note that it also demolishes anything for repredicting,
        # which should be ok in most of the cases
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

            if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
                raise ValueError, \
                      "Retraining for changed params not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)
        # mark for future 'train()' items which are explicitly
        # mentioned as changed
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # To check if we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although wasn't " \
                              "labeled as such" % key

        # TODO: parameters of classifiers... for now there is explicit
        # 'forbidance' above

        # Below check should be superseded by check above, thus never occur.
        # remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whereas previously was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)

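    # A hedged sketch of the intended retrain() workflow described above
    # (`SomeRetrainableClassifier` is a hypothetical subclass, and
    # `dataset.permuteLabels(True)` assumes the Dataset API of this
    # vintage -- treat both as illustrative only):
    #
    #   clf = SomeRetrainableClassifier(retrainable=True)
    #   clf.train(dataset)                 # full training, e.g. kernel computed
    #   dataset.permuteLabels(True)        # only labels changed, samples intact
    #   clf.retrain(dataset, labels=True)  # cheap: kernel not recomputed
    #   predictions = clf.repredict(dataset.samples)
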
    def repredict(self, data, **kwargs):
        """Helper to avoid checking whether data has actually changed

        Useful if the classifier was (re)trained but with the same data
        (so just parameters were changed), so that it could be
        repredicted easily (on the same data as before) without
        recomputing, for instance, the train/test kernel matrix. Should be
        used with caution and always compared to the results on a not
        'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        if len(kwargs)>0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True

        # check if we are attempting to perform on the same data
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                # so it wasn't told to be invalid
                #if not chd[key]:# and not ichd.get(key, False):
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although wasn't " \
                          "labeled as such" % key

        # Should be superseded by above
        # remove in future???
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whereas previously was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)


    # TODO: callback into retrainable parameter
    #retrainable = property(fget=_getRetrainable, fset=_setRetrainable,
    #    doc="Specifies whether the classifier should be retrainable")
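
# ----------------------------------------------------------------------------
# A minimal sketch (added for illustration, not part of the original module)
# of how a concrete classifier derives from `Classifier`: only _train() and
# _predict() need overriding.  The Dataset construction assumes the
# mvpa.datasets API of the same vintage; treat this as a hedged example
# rather than a definitive recipe.

if __name__ == '__main__':

    from mvpa.datasets import Dataset

    class NearestMeanClassifier(Classifier):
        """Toy classifier: predicts the label whose training mean is closest"""

        def _train(self, dataset):
            # remember the per-label mean sample from the training data
            self._means = [(label,
                            dataset.samples[dataset.labels == label].mean(axis=0))
                           for label in dataset.uniquelabels]

        def _predict(self, data):
            # pick the label of the closest class mean (Euclidean distance)
            return [min(self._means,
                        key=lambda lm: N.linalg.norm(sample - lm[1]))[0]
                    for sample in data]

    ds = Dataset(samples=N.random.randn(20, 5), labels=[0]*10 + [1]*10)
    clf = NearestMeanClassifier()
    clf.train(ds)
    print clf.predict(ds.samples)
    print clf.summary()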