1
2
3
4
5
6
7
8
9 """Unit tests for SVM classifier"""
10
11 from sets import Set
12
13 from mvpa.datasets.splitters import NFoldSplitter
14 from mvpa.clfs.meta import ProxyClassifier
15 from mvpa.clfs.transerror import TransferError
16 from mvpa.algorithms.cvtranserror import CrossValidatedTransferError
17
18 from tests_warehouse import pureMultivariateSignal
19 from tests_warehouse import *
20 from tests_warehouse_clfs import *
21
23
24
25
# NOTE(review): this is the body of a test method whose enclosing `class` and
# `def` lines are not present in this view; indentation is also missing and each
# line carries a stray numeric prefix, so the fragment cannot run as shown.
# Purpose (from the code below): over 20 rounds, compare the first non-linear
# and the first linear SVM from the `clfswh` warehouse on freshly generated
# `pureMultivariateSignal(20, 3)` data, and also the non-linear SVM trained on
# feature 0 only.
27 mv_perf = []
28 mv_lin_perf = []
29 uv_perf = []
30
# Take the first classifier matching each tag set.
31 l_clf = clfswh['linear', 'svm'][0]
32 nl_clf = clfswh['non-linear', 'svm'][0]
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Check that an untrained non-linear SVM can be deep-copied.
# NOTE(review): the bare `except:` also traps KeyboardInterrupt/SystemExit;
# `except Exception:` would be narrower. `nl_clf_copy` is not used further in
# this view.
52 import mvpa.support.copy as copy
53 try:
54 nl_clf.untrain()
55 nl_clf_copy = copy.deepcopy(nl_clf)
56 except:
57 self.fail(msg="Failed to deepcopy not-yet trained SVM %s" % nl_clf)
58
# NOTE(review): `xrange` is Python 2 only; the loop index `i` is unused.
59 for i in xrange(20):
60 train = pureMultivariateSignal( 20, 3 )
61 test = pureMultivariateSignal( 20, 3 )
62
63
# Non-linear SVM on all features. Per-round accuracy is the mean of the
# element-wise prediction == label comparison (assuming `N` is numpy, brought
# in by a star import — not visible here).
64 nl_clf.train(train)
65 p_mv = nl_clf.predict(test.samples)
66 mv_perf.append(N.mean(p_mv==test.labels))
67
68
# Linear SVM on the same train/test data.
69 l_clf.train(train)
70 p_lin_mv = l_clf.predict(test.samples)
71 mv_lin_perf.append(N.mean(p_lin_mv==test.labels))
72
73
# Univariate control: retrain the same `nl_clf` on feature 0 only. Presumably
# the generated signal is only decodable from the feature combination, which
# is why this is expected to score lower — confirm against the generator.
74 nl_clf.train(train.selectFeatures([0]))
75 p_uv = nl_clf.predict(test.selectFeatures([0]).samples)
76 uv_perf.append(N.mean(p_uv==test.labels))
77
78 mean_mv_perf = N.mean(mv_perf)
79 mean_mv_lin_perf = N.mean(mv_lin_perf)
80 mean_uv_perf = N.mean(uv_perf)
81
82
# Expectations: the non-linear SVM averages > 0.9 accuracy, beats the linear
# SVM, and beats its own feature-0-only variant.
# NOTE(review): `failUnless` is a deprecated alias of `assertTrue` (the alias
# was removed in Python 3.12).
83 self.failUnless( mean_mv_perf > 0.9 )
84
85 self.failUnless( mean_mv_perf > mean_mv_lin_perf )
86
87 self.failUnless( mean_uv_perf < mean_mv_perf )
88
89
90
91
92
93
94
95
# NOTE(review): the `def` line for this decorated test method is not present in
# this view (the decorator is followed directly by the body); indentation is
# missing and each line carries a stray numeric prefix.
# Purpose (from the code below): for each SVM supplied by `sweepargs`, check
# how the cross-validated confusion statistic stats["P'"][1] behaves on a noisy
# dataset, on a resampled (duplicated-samples) dataset, and on the resampled
# dataset after C is set to a per-label tuple.
96 @sweepargs(clf=clfswh['svm', 'sg', '!regression', '!gnpp', '!meta'])
# Return without any assertion when clf.C > 0, when reading/comparing `clf.C`
# raises anything, or when clf.C < -5, i.e. only -5 <= C <= 0 is exercised.
# NOTE(review): the bare `except:` also swallows KeyboardInterrupt/SystemExit;
# presumably it is meant for classifiers without a `C` attribute — consider
# narrowing it.
98 try:
99 if clf.C > 0:
100
101 return
102 except:
103
104 return
105
106 if clf.C < -5:
107
108
109 return
110
# `ds__`: a copy of the 'uni2small' dataset with Gaussian noise (scale 0.5)
# added to its samples.
111 ds = datasets['uni2small'].copy()
112 ds__ = datasets['uni2small'].copy()
113
114
115
116
117 ds__.samples = ds__.samples + 0.5 * N.random.normal(size=(ds__.samples.shape))
118
119
120
# `ds_`: all samples of `ds` plus the first nsamples/2 samples (by index)
# repeated `times` times, then Gaussian noise (scale 0.7) added. Presumably this
# over-represents one label if samples are ordered by label — confirm.
# NOTE(review): list `+` on `range(...)` results and integer `/` here are
# Python 2 only.
121 times = 10
122 ds_ = ds.selectSamples(range(ds.nsamples) + range(ds.nsamples/2) * times)
123 ds_.samples = ds_.samples + 0.7 * N.random.normal(size=(ds_.samples.shape))
# `spl` presumably maps label -> number of samples in `ds_`; confirm.
124 spl = ds_.samplesperlabel
125
126
# The return values `e` / `e_` are unused; the assertions read `cve.confusion`
# after each call. All checks are gated by the `labile` test config flag
# (default 'yes'). The meaning of stats["P'"][1] is not visible here —
# presumably a per-label prediction count; confirm against the confusion code.
127 cve = CrossValidatedTransferError(TransferError(clf), NFoldSplitter(),
128 enable_states='confusion')
129 e = cve(ds__)
130 if cfg.getboolean('tests', 'labile', default='yes'):
131
132 self.failUnless(cve.confusion.stats["P'"][1] > 0)
133
# On the resampled `ds_`, the same statistic is expected to be < 5.
# NOTE(review): "disballance" in the message is a typo; it is a runtime string,
# so it is left unchanged here.
134 e = cve(ds_)
135 if cfg.getboolean('tests', 'labile', default='yes'):
136 self.failUnless(cve.confusion.stats["P'"][1] < 5,
137 msg="With disballance we should have almost no "
138 "hits. Got %f" % cve.confusion.stats["P'"][1])
139
140
141
# Set C to a two-element tuple built from ratio = sqrt(spl[0] / spl[1]): both
# entries are negative, their magnitudes multiply to 1, and
# |C[0]| / |C[1]| == spl[1] / spl[0]. Presumably this re-weights the labels to
# compensate for the imbalance — confirm against the SVM implementation.
142 oldC = clf.C
143 ratio = N.sqrt(float(spl[0])/spl[1])
144 clf.C = (-1/ratio, -1*ratio)
# Restore the original C whether or not cross-validation raises.
# NOTE(review): this try / restore / except-restore-raise pattern is
# effectively `try: ... finally: clf.C = oldC`, which is shorter and clearer.
145 try:
146 e_ = cve(ds_)
147
148 clf.C = oldC
149 except:
150 clf.C = oldC
151 raise
152
# With the per-label C, the statistic is expected to be > 0 again.
153 if cfg.getboolean('tests', 'labile', default='yes'):
154
155
156 self.failUnless(cve.confusion.stats["P'"][1] > 0)
157
158
160 """Test if we raise exceptions on incorrect specifications
161 """
# NOTE(review): the `def` line for this test method is not present in this
# view; indentation is missing and each line carries a stray numeric prefix.
# `externals`, `SVM`, `libsvm`, `sg`, `LinearNuSVMC` and `LinearCSVMC` are not
# explicitly imported in the visible lines — presumably they come from the
# `tests_warehouse*` star imports.
162
# Each call passes a parameter combination that the test expects to raise
# TypeError: C together with nu; C for `LinearNuSVMC`; nu for `LinearCSVMC`;
# coef0 with kernel_type='RBF'. Each backend is checked only when
# `externals.exists` reports it available.
# NOTE(review): `failUnlessRaises` is a deprecated alias of `assertRaises`
# (the alias was removed in Python 3.12).
163 if externals.exists('libsvm') or externals.exists('shogun'):
164 self.failUnlessRaises(TypeError, SVM, C=1.0, nu=2.3)
165
166 if externals.exists('libsvm'):
167 self.failUnlessRaises(TypeError, libsvm.SVM, C=1.0, nu=2.3)
168 self.failUnlessRaises(TypeError, LinearNuSVMC, C=2.3)
169 self.failUnlessRaises(TypeError, LinearCSVMC, nu=2.3)
170
171 if externals.exists('shogun'):
172 self.failUnlessRaises(TypeError, sg.SVM, C=10, kernel_type='RBF',
173 coef0=3)
174
177
178
# Script entry point. The only visible action is importing the project-local
# `runner` module; presumably that import runs the tests in this file —
# confirm against runner.py.
179 if __name__ == '__main__':
180 import runner
181