1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA classifier cross-validation"""
10
11 import unittest
12 from mvpa.support.copy import copy
13
14 from mvpa.base import externals
15 from mvpa.datasets import Dataset
16 from mvpa.datasets.splitters import OddEvenSplitter
17
18 from mvpa.clfs.meta import MulticlassClassifier
19 from mvpa.clfs.transerror import \
20 TransferError, ConfusionMatrix, ConfusionBasedError
21 from mvpa.algorithms.cvtranserror import CrossValidatedTransferError
22
23 from mvpa.clfs.stats import MCNullDist
24
25 from mvpa.misc.exceptions import UnknownStateError
26
27 from tests_warehouse import datasets, sweepargs
28 from tests_warehouse_clfs import *
29
31
# Fixture: true labels `reg`, predicted labels `regl`, and the confusion
# matrix `correct_cm` the test expects ConfusionMatrix to build from them.
reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]

# ConfusionMatrix must accept targets/predictions as list, tuple or array.
for t in [reg, tuple(reg), list(reg), N.array(reg)]:
    for p in [regl, tuple(regl), list(regl), N.array(regl)]:
        cm = ConfusionMatrix(targets=t, predictions=p)
        self.failUnless((cm.matrix == correct_cm).all())

# No samples yet, so percentCorrect must raise.
cm = ConfusionMatrix()
self.failUnlessRaises(ZeroDivisionError, lambda x: x.percentCorrect, cm)

cm.add(reg, regl)

self.failUnlessEqual(len(cm.sets), 1,
                     msg="Should have a single set so far")
self.failUnlessEqual(cm.matrix.shape, (3, 3),
                     msg="should be square matrix (len(reglabels) x len(reglabels)")

# ConfusionMatrix must complain if the number of samples differs.
self.failUnlessRaises(ValueError, cm.add, reg, N.array([1]))

# The rejected add() above must have left the accumulated matrix intact.
self.failUnless((cm.matrix == correct_cm).all())

# A second set whose predictions introduce a new label (4).
cm.add(reg, N.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

self.failUnlessEqual(cm.labels, [1, 2, 3, 4],
                     msg="We should have gotten 4th label")

matrices = cm.matrices
self.failUnlessEqual(len(matrices), 2,
                     msg="Have gotten two splits")

self.failUnless((matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
                msg="Total votes should match the sum across split CMs")

# Only check that the textual reports are non-trivially long.
self.failUnless(len(cm.asstring(
    header=True, summary=True,
    description=True)) > 100)
self.failUnless(len(str(cm)) > 100)

self.failUnless(len(cm.asstring(summary=True,
                                header=False)) > 100)

# In-place addition doubles the number of stored sets (2 -> 4).
cm += cm
self.failUnlessEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

# Binary addition doubles them again and must not change accuracy.
cm2 = cm + cm
self.failUnlessEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
self.failUnlessEqual(cm2.percentCorrect, cm.percentCorrect,
                     msg="Percent of corrrect should remain the same ;-)")

# `error` is a float derived from the accuracy; compare with a tolerance
# rather than exact equality to avoid spurious last-bit mismatches.
self.failUnlessAlmostEqual(cm2.error, 1.0 - cm.percentCorrect / 100.0,
                           msg="Test if we get proper error value")
101
102
103
# Degenerate inputs: a single label (possibly repeated) predicted perfectly.
for orig in ([1], [1, 1], [0], [0, 0]):
    cm = ConfusionMatrix(targets=orig, predictions=orig, values=orig)

    # Rendering the report must not raise; its content is not inspected.
    str(cm)
    # failUnlessEqual reports both values if the accuracy is off.
    self.failUnlessEqual(cm.stats['ACC%'], 100)
110
# Two of the four predictions match their targets, so ACC% must be 50.
reg = [0, 0, 1, 1]
regl = [1, 0, 1, 0]
cm = ConfusionMatrix(targets=reg, predictions=regl)
# Collapse whitespace runs before matching. The check then holds however
# many alignment spaces the report puts between 'ACC%' and its value.
self.failUnless('ACC% 50' in ' '.join(str(cm).split()))
116
117
# Two names ('apple', 'shitty apple') map to the same numeric label 1.
# The matrix must still match the expected counts, and every name must
# show up in the textual report.
reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
lm = {'apple': 1, 'orange': 2, 'shitty apple': 1, 'candy': 3}
cm = ConfusionMatrix(targets=reg, predictions=regl,
                     labels_map=lm)

self.failUnless((cm.matrix == correct_cm).all())

report = str(cm)
for name in lm:
    self.failUnless(name in report)
132
133
134
135 @sweepargs(l_clf=clfswh['linear', 'svm'])
train = datasets['uni2medium_train']
# NOTE(review): named test3 but drawn from a *_train dataset -- confirm intended.
test3 = datasets['uni3medium_train']
err = ConfusionBasedError(clf=l_clf)
terr = TransferError(clf=l_clf)

# The classifier has not been trained yet, so the state that
# ConfusionBasedError reads must not be accessible.
self.failUnlessRaises(UnknownStateError, err, None)

l_clf.train(train)
self.failUnlessEqual(err(None), terr(train),
                     msg="ConfusionBasedError should be equal to TransferError on"
                         " traindataset")

self.failIf(terr(test3) is None)

# Smoke check: a TransferError must be copyable without raising.
copy(terr)
157
158
159 @sweepargs(l_clf=clfswh['linear', 'svm'])
181
182
183 @sweepargs(l_clf=clfswh['linear', 'svm'])
197
198
199 @sweepargs(clf=clfswh['multiclass'])
"""Test AUC computation
"""
# MulticlassClassifier instances are skipped entirely.
if not isinstance(clf, MulticlassClassifier):
    clf.states._changeTemporarily(enable_states=['values'])
    try:
        # ds2: labels flipped via `1 - label` (assumes 0/1 labels).
        ds2 = datasets['uni2small'].copy()
        ds2.labels = 1 - ds2.labels

        # ds3: each of the three unique labels replaced by the next one.
        ds3 = datasets['uni3small'].copy()
        ul = ds3.uniquelabels
        nl = ds3.labels.copy()
        for l in xrange(3):
            nl[ds3.labels == ul[l]] = ul[(l + 1) % 3]
        ds3.labels = nl

        for ds in [datasets['uni2small'], ds2,
                   datasets['uni3small'], ds3]:
            cv = CrossValidatedTransferError(
                TransferError(clf),
                OddEvenSplitter(),
                enable_states=['confusion', 'training_confusion'])
            cv(ds)
            stats = cv.confusion.stats
            Nlabels = len(ds.uniquelabels)

            self.failUnless(stats['ACC'] > 1.2 / Nlabels)
            auc = stats['AUC']
            # A NaN produced by numpy arithmetic is a different object from
            # N.nan, so an `is` test never detects it; use N.isnan instead.
            if (Nlabels == 2) or (Nlabels > 2 and not N.isnan(auc[0])):
                mauc = N.min(auc)
                if cfg.getboolean('tests', 'labile', default='yes'):
                    self.failUnless(mauc > 0.55,
                                    msg='All AUCs must be above chance. Got minimal '
                                        'AUC=%.2g among %s' % (mauc, stats['AUC']))
    finally:
        # Restore the classifier's states even when an assertion fails.
        clf.states._resetEnabledTemporarily()
236
237
238
239
"""Based on existing cell dataset results.

Left in for possible future testing, but is not part of the
unittests suite
"""
# Short aliases so the `sets` literal that follows can be written in
# numpy repr() form, e.g. array([...], dtype=uint8).
array = N.array
uint8 = N.uint8
257 sets = [
258 (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
259 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
260 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
261 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
262 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
263 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
264 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
265 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
266 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
267 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
268 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
269 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
270 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44], dtype=uint8),
271 array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44,
272 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46,
273 45, 38, 44, 39, 46, 38, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
274 40, 47, 43, 45, 41, 44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39,
275 46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45,
276 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38,
277 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46,
278 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41,
279 44, 47, 46, 42, 47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40,
280 43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
281 47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38, 38, 43, 45, 41, 44,
282 40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 42, 43,
283 45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
284 (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
285 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
286 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
287 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
288 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
289 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
290 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
291 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
292 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
293 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
294 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
295 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
296 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
297 array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43,
298 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
299 39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40,
300 46, 42, 43, 39, 38, 43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45,
301 41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47, 40,
302 40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44, 40, 46,
303 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41,
304 44, 40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
305 43, 45, 41, 44, 47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42,
306 47, 39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
307 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 46, 47, 38, 39, 43,
308 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39,
309 39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
310 (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
311 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
312 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
313 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
314 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
315 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
316 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
317 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
318 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
319 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
320 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
321 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
322 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
323 array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47,
324 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40,
325 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45,
326 40, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
327 40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40, 40,
328 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
329 44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38,
330 43, 45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42,
331 47, 40, 39, 43, 45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44,
332 40, 40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
333 45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38, 44, 40, 46, 42, 39,
334 39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43, 45, 41, 38, 40,
335 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
336 (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
337 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
338 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
339 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
340 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
341 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
342 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
343 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
344 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
345 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
346 39, 38, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 41, 44, 40, 46,
347 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
348 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
349 array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40,
350 46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45,
351 41, 44, 40, 38, 43, 47, 38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39,
352 38, 43, 45, 41, 44, 43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
353 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45, 40,
354 44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
355 43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42,
356 38, 39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39,
357 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43,
358 45, 41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
359 38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 46, 44, 38, 46,
360 42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
361 44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
362 (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
363 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
364 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
365 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
366 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
367 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
368 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
369 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
370 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
371 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
372 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
373 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
374 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
375 array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45,
376 42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
377 40, 43, 45, 41, 44, 41, 46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46,
378 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41,
379 44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
380 43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39, 46, 42,
381 47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
382 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43,
383 45, 41, 44, 40, 38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46,
384 46, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
385 38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 47, 45,
386 46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
387 38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
388 (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
389 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
390 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
391 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
392 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
393 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
394 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
395 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
396 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
397 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
398 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
399 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
400 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
401 array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40,
402 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38,
403 42, 40, 38, 39, 43, 45, 41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
404 44, 40, 46, 42, 38, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39,
405 43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42,
406 40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45, 41, 44,
407 40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43,
408 45, 41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40,
409 38, 38, 40, 45, 41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40,
410 40, 42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
411 41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44, 39, 43, 42, 47, 39,
412 40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45, 41, 44, 38, 46,
413 42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
414 (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
415 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
416 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
417 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
418 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
419 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
420 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
421 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
422 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
423 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
424 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
425 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
426 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
427 array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46,
428 42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
429 44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38,
430 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
431 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
432 40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39, 38, 43,
433 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39,
434 39, 47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
435 46, 42, 46, 47, 39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45,
436 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
437 38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43, 47, 40, 44, 40, 46,
438 42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
439 44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
440 (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
441 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
442 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
443 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
444 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
445 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
446 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
447 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
448 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
449 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
450 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
451 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
452 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
453 array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41,
454 44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46,
455 43, 45, 41, 44, 41, 41, 42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42,
456 39, 39, 40, 43, 45, 41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44,
457 39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39, 43,
458 45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46, 42, 47,
459 41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41,
460 46, 42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
461 41, 44, 38, 46, 42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38,
462 38, 43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
463 42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41, 39, 39, 43, 45, 41,
464 44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46, 42, 47, 39, 38,
465 43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
# Stimulus names -> numeric labels (38..47) found in the `sets` arrays,
# listed here in ascending label order.
labels_map = {
    '3kHz': 38,
    '7kHz': 39,
    '12kHz': 40,
    '20kHz': 41,
    '30kHz': 42,
    'song1': 43,
    'song2': 44,
    'song3': 45,
    'song4': 46,
    'song5': 47,
}
# Construction errors propagate, so the test reports the real exception
# and traceback instead of an uninformative failure.
cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
self.failUnless('3kHz / 38' in cm.asstring())

if externals.exists("pylab plottable"):
    import pylab as P
    P.figure()
    # NOTE(review): the None entry presumably separates the tone and song
    # groups in the plot -- confirm against ConfusionMatrix.plot.
    labels_order = ("3kHz", "7kHz", "12kHz", "20kHz", "30kHz", None,
                    "song1", "song2", "song3", "song4", "song5")

    # Plot with the first two labels swapped; the plotted matrix must be
    # the original one with rows/columns 0 and 1 exchanged.
    swapped_order = labels_order[1:2] + labels_order[:1] + labels_order[2:]
    fig, im, cb = cm.plot(labels=swapped_order, numbers=True)
    self.failUnless(cm._plotted_confusionmatrix[0, 0] == cm.matrix[1, 1])
    self.failUnless(cm._plotted_confusionmatrix[0, 1] == cm.matrix[1, 0])
    self.failUnless(cm._plotted_confusionmatrix[1, 1] == cm.matrix[0, 0])
    self.failUnless(cm._plotted_confusionmatrix[1, 0] == cm.matrix[0, 1])
    P.close(fig)

    # The unswapped order must plot without error as well.
    fig, im, cb = cm.plot(labels=labels_order, numbers=True)
    P.close(fig)
498
499
"""Based on a sample confusion which plots incorrectly.
"""
array = N.array
# NOTE(review): each entry looks like (targets, predictions, per-sample
# values), matching ConfusionMatrix's keyword order -- confirm.
sets = [(array([1, 2]), array([1, 1]),
         array([[0.54343765, 0.45656235],
                [0.92395853, 0.07604147]])),
        (array([1, 2]), array([1, 1]),
         array([[0.98030832, 0.01969168],
                [0.78998763, 0.21001237]])),
        (array([1, 2]), array([1, 1]),
         array([[0.86125263, 0.13874737],
                [0.83674113, 0.16325887]])),
        (array([1, 2]), array([1, 1]),
         array([[0.57870383, 0.42129617],
                [0.59702509, 0.40297491]])),
        (array([1, 2]), array([1, 1]),
         array([[0.89530255, 0.10469745],
                [0.69373919, 0.30626081]])),
        (array([1, 2]), array([1, 1]),
         array([[0.75015218, 0.24984782],
                [0.9339767, 0.0660233]])),
        (array([1, 2]), array([1, 2]),
         array([[0.97826616, 0.02173384],
                [0.38620638, 0.61379362]])),
        (array([2]), array([2]),
         array([[0.46893776, 0.53106224]]))]

# Construction errors propagate so the real traceback is reported.
cm = ConfusionMatrix(sets=sets)

if externals.exists("pylab plottable"):
    import pylab as P
    # With origin='lower' the plotted matrix must still equal cm.matrix.
    fig, im, cb = cm.plot(origin='lower', numbers=True)
    self.failUnless((cm._plotted_confusionmatrix == cm.matrix).all())
    P.close(fig)
541 P.close(fig)
542
543
544
545
546
549
550
if __name__ == '__main__':
    # NOTE(review): importing the project's `runner` module presumably
    # runs the test suite as an import side effect -- confirm in runner.py.
    import runner
553