Package mvpa :: Package tests :: Module test_params
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_params

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA Parameter class.""" 
 10   
 11  import unittest, copy 
 12   
 13  import numpy as N 
 14  from sets import Set 
 15   
 16  from mvpa.datasets import Dataset 
 17  from mvpa.misc.state import ClassWithCollections, StateVariable 
 18  from mvpa.misc.param import Parameter, KernelParameter 
 19   
 20  from tests_warehouse_clfs import SameSignClassifier 
 21   
class ParametrizedClassifier(SameSignClassifier):
    """SameSignClassifier test double carrying one regular and one kernel parameter."""

    # regular parameter, expected by the tests to land in ``params``
    p1 = Parameter(1.0)
    # kernel parameter, expected by the tests to land in ``kernel_params``
    kp1 = KernelParameter(100.0)
25
class ParametrizedClassifierExtended(ParametrizedClassifier):
    """ParametrizedClassifier that registers a second kernel parameter per instance."""

    def __init__(self):
        ParametrizedClassifier.__init__(self)
        # ``kp2`` is added at construction time to this instance's kernel
        # parameter collection rather than declared as a class attribute.
        extra = KernelParameter(200.0, doc="Very useful param", name="kp2")
        self.kernel_params.add(extra)
30
class BlankClass(ClassWithCollections):
    """Collection-aware class that declares no parameters and no states."""
33
class SimpleClass(ClassWithCollections):
    """Collection-aware class holding exactly one Parameter, ``C``."""

    # non-negative parameter with default 1.0
    C = Parameter(1.0, min=0, doc="C parameter")
36
class MixedClass(ClassWithCollections):
    """Collection-aware class mixing two Parameters with one StateVariable."""

    # regular parameters (collected into ``params``)
    C = Parameter(1.0, min=0, doc="C parameter")
    D = Parameter(3.0, min=0, doc="D parameter")
    # state variable (collected separately into ``states``)
    state1 = StateVariable(doc="bogus")
41
class ParamsTests(unittest.TestCase):
    """Unit tests for Parameter handling in ClassWithCollections subclasses.

    Uses the ``assertEqual``/``assertRaises`` spellings of the unittest API;
    the ``failUnless*`` aliases are deprecated and were removed in Python 3.12.
    """

    def testBlank(self):
        """A class with no declared collections exposes no ``states``."""
        blank = BlankClass()

        self.assertRaises(AttributeError, blank.__getattribute__, 'states')
        # an empty attribute name is expected to raise IndexError,
        # not AttributeError
        self.assertRaises(IndexError, blank.__getattribute__, '')

    def testSimple(self):
        """Defaults, assignment, and reset of a single Parameter."""
        simple = SimpleClass()

        self.assertEqual(len(simple.params.items), 1)
        self.assertRaises(AttributeError, simple.__getattribute__, 'dummy')
        self.assertRaises(IndexError, simple.__getattribute__, '')

        # freshly constructed: default value, nothing marked as set
        self.assertEqual(simple.C, 1.0)
        self.assertEqual(simple.params.isSet("C"), False)
        self.assertEqual(simple.params.isSet(), False)
        self.assertEqual(simple.params["C"].isDefault, True)
        self.assertEqual(simple.params["C"].equalDefault, True)

        simple.C = 1.0
        # we are not actually setting the value if == default
        self.assertEqual(simple.params["C"].isDefault, True)
        self.assertEqual(simple.params["C"].equalDefault, True)

        simple.C = 10.0
        self.assertEqual(simple.params.isSet("C"), True)
        self.assertEqual(simple.params.isSet(), True)
        self.assertEqual(simple.params["C"].isDefault, False)
        self.assertEqual(simple.params["C"].equalDefault, False)

        self.assertEqual(simple.C, 10.0)
        simple.params["C"].resetvalue()
        # resetting restores the default value, yet the parameter is still
        # reported as set
        self.assertEqual(simple.params.isSet("C"), True)
        # TODO: Test if we 'train' a classifier if we get isSet to false
        self.assertEqual(simple.C, 1.0)
        self.assertRaises(AttributeError, simple.params.__getattribute__, 'B')

    def testMixed(self):
        """Parameters and states coexist; setting one param leaves others unset."""
        mixed = MixedClass()

        self.assertEqual(len(mixed.params.items), 2)
        self.assertEqual(len(mixed.states.items), 1)
        # no KernelParameter declared, so no kernel_params collection
        self.assertRaises(AttributeError, mixed.__getattribute__, 'kernel_params')

        self.assertEqual(mixed.C, 1.0)
        self.assertEqual(mixed.params.isSet("C"), False)
        self.assertEqual(mixed.params.isSet(), False)
        mixed.C = 10.0
        self.assertEqual(mixed.params.isSet("C"), True)
        self.assertEqual(mixed.params.isSet("D"), False)
        self.assertEqual(mixed.params.isSet(), True)
        self.assertEqual(mixed.D, 3.0)

    def testClassifier(self):
        """Kernel parameters on classifiers, including runtime additions and training."""
        clf = ParametrizedClassifier()
        # p1 plus (presumably) the base classifier's regression/retrainable
        self.assertEqual(len(clf.params.items), 3)
        self.assertEqual(len(clf.kernel_params.items), 1)

        clfe = ParametrizedClassifierExtended()
        self.assertEqual(len(clfe.params.items), 3)
        self.assertEqual(len(clfe.kernel_params.items), 2)
        self.assertEqual(len(clfe.kernel_params.listing), 2)

        # check assignment once again
        self.assertEqual(clfe.kp2, 200.0)
        clfe.kp2 = 201.0
        self.assertEqual(clfe.kp2, 201.0)
        self.assertEqual(clfe.kernel_params.isSet("kp2"), True)
        # training is expected to clear all "set" flags
        clfe.train(Dataset(samples=[[0, 0]], labels=[1], chunks=[1]))
        self.assertEqual(clfe.kernel_params.isSet("kp2"), False)
        self.assertEqual(clfe.kernel_params.isSet(), False)
        self.assertEqual(clfe.params.isSet(), False)
117
def suite():
    """Return a TestSuite containing every ``test*`` method of ParamsTests."""
    # TestLoader.loadTestsFromTestCase is the supported equivalent of
    # unittest.makeSuite, which is deprecated and was removed in Python 3.13.
    return unittest.TestLoader().loadTestsFromTestCase(ParamsTests)


if __name__ == '__main__':
    # NOTE(review): presumably importing the project's ``runner`` module is
    # what executes the tests -- confirm against runner.py
    import runner