1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA pattern handling"""
10
11 from mvpa.datasets.masked import MaskedDataset
12 from mvpa.datasets.splitters import NFoldSplitter, OddEvenSplitter, \
13 NoneSplitter, HalfSplitter, \
14 CustomSplitter, NGroupSplitter
15 import unittest
16 import numpy as N
17
18
20
# Shared fixture: 100 samples of 10 normally distributed features each.
# Labels cycle through 0..3, and every run of ten consecutive samples
# shares one chunk id, which yields the chunk ids 0..9.
self.data = MaskedDataset(
    samples=N.random.normal(size=(100, 10)),
    labels=[i % 4 for i in range(100)],
    # Floor division keeps the chunk ids integral even where `/` means
    # true division.
    chunks=[i // 10 for i in range(100)])
26
27
29
# N-fold splitter with cvtype=1; per the assertions below every fold
# holds out 10 samples and trains on the other 90.
fold_splitter = NFoldSplitter(cvtype=1)

# Materialize every (training, testing) pair the splitter generates.
folds = [(training, testing)
         for (training, testing) in fold_splitter(self.data)]

# One fold is expected for each of the 10 chunks in the fixture.
self.failUnless(len(folds) == 10)

for fold_no, fold in enumerate(folds):
    self.failUnless(len(fold) == 2)
    training, testing = fold
    self.failUnless(training.nsamples == 90)
    self.failUnless(testing.nsamples == 10)
    # Folds are expected in chunk order: fold k tests on chunk k.
    self.failUnless(testing.chunks[0] == fold_no)
42
43
# Default OddEvenSplitter; the assertions below expect the data to be
# divided by odd versus even chunk ids.
oes = OddEvenSplitter()

splits = [(train, test) for (train, test) in oes(self.data)]

self.failUnless(len(splits) == 2)

# Each split is a pair of datasets holding half of the samples each.
for p in splits:
    self.failUnless(len(p) == 2)
    self.failUnless(p[0].nsamples == 50)
    self.failUnless(p[1].nsamples == 50)

# The second split is the first one with its two datasets swapped.
self.failUnless((splits[0][1].uniquechunks == [1, 3, 5, 7, 9]).all())
self.failUnless((splits[0][0].uniquechunks == [0, 2, 4, 6, 8]).all())
self.failUnless((splits[1][0].uniquechunks == [1, 3, 5, 7, 9]).all())
self.failUnless((splits[1][1].uniquechunks == [0, 2, 4, 6, 8]).all())

# Re-split a dataset that only contains even chunk ids: both datasets
# of every resulting split must still be present.
moresplits = [(train, test) for (train, test) in oes(splits[0][0])]

for split in moresplits:
    # Identity test against None avoids any overloaded `!=` operator.
    self.failUnless(split[0] is not None)
    self.failUnless(split[1] is not None)
67
68
# Default HalfSplitter; the assertions below expect the chunks to be
# divided into a lower half (0..4) and an upper half (5..9).
hs = HalfSplitter()

splits = [(train, test) for (train, test) in hs(self.data)]

self.failUnless(len(splits) == 2)

# Each split is a pair of datasets holding half of the samples each.
for p in splits:
    self.failUnless(len(p) == 2)
    self.failUnless(p[0].nsamples == 50)
    self.failUnless(p[1].nsamples == 50)

# The second split is the first one with its two datasets swapped.
self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all())
self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all())

# Re-split a dataset that only contains the upper half of the chunks:
# both datasets of every resulting split must still be present.
moresplits = [(train, test) for (train, test) in hs(splits[0][0])]

for split in moresplits:
    # Identity test against None avoids any overloaded `!=` operator.
    self.failUnless(split[0] is not None)
    self.failUnless(split[1] is not None)
92
94 """Test NGroupSplitter alongside with the reversal of the
95 order of spit out datasets
96 """
97
# Two groups: the assertions below expect the same lower/upper halves
# of the chunk ids that a split into halves would produce.
hs = NGroupSplitter(2)
hs_reversed = NGroupSplitter(2, reverse=True)

# `isreversed` is 0 for the plain splitter and 1 for the reversed one,
# so it doubles as the index that flips the two datasets of a split.
for isreversed, splitter in enumerate((hs, hs_reversed)):
    splits = list(splitter(self.data))
    self.failUnless(len(splits) == 2)

    for p in splits:
        self.failUnless(len(p) == 2)
        self.failUnless(p[0].nsamples == 50)
        self.failUnless(p[1].nsamples == 50)

    self.failUnless((splits[0][1-isreversed].uniquechunks == [0, 1, 2, 3, 4]).all())
    self.failUnless((splits[0][isreversed].uniquechunks == [5, 6, 7, 8, 9]).all())
    self.failUnless((splits[1][1-isreversed].uniquechunks == [5, 6, 7, 8, 9]).all())
    self.failUnless((splits[1][isreversed].uniquechunks == [0, 1, 2, 3, 4]).all())

# Re-split one half of the data with the plain splitter: both datasets
# of every resulting split must still be present.
moresplits = list(hs(splits[0][0]))

for split in moresplits:
    # Identity test against None avoids any overloaded `!=` operator.
    self.failUnless(split[0] is not None)
    self.failUnless(split[1] is not None)

# Five groups, again in plain and reversed flavour.
s5 = NGroupSplitter(5)
s5_reversed = NGroupSplitter(5, reverse=True)

for isreversed, s5splitter in enumerate((s5, s5_reversed)):
    splits = list(s5splitter(self.data))

    # One split per group.
    self.failUnless(len(splits) == 5)

    # Each group covers two neighbouring chunks; the other dataset of
    # the split holds the remaining eight chunks.
    self.failUnless((splits[0][1-isreversed].uniquechunks == [0, 1]).all())
    self.failUnless((splits[0][isreversed].uniquechunks == [2, 3, 4, 5, 6, 7, 8, 9]).all())
    self.failUnless((splits[1][1-isreversed].uniquechunks == [2, 3]).all())
    self.failUnless((splits[1][isreversed].uniquechunks == [0, 1, 4, 5, 6, 7, 8, 9]).all())

    self.failUnless((splits[4][1-isreversed].uniquechunks == [8, 9]).all())
    self.failUnless((splits[4][isreversed].uniquechunks == [0, 1, 2, 3, 4, 5, 6, 7]).all())


# Asking for more groups (20) than the fixture has chunks (10) must
# raise ValueError once the splitter is actually iterated.
def splitcall(spl, dat):
    return [(train, test) for (train, test) in spl(dat)]

s20 = NGroupSplitter(20)
self.assertRaises(ValueError, splitcall, s20, self.data)
148
150
# Two rules that only name the chunks of the second dataset; per the
# assertions below, None makes the first dataset take all other chunks.
hs = CustomSplitter([(None, [0, 1, 2, 3, 4]), (None, [5, 6, 7, 8, 9])])
splits = list(hs(self.data))
self.failUnless(len(splits) == 2)

for p in splits:
    self.failUnless(len(p) == 2)
    self.failUnless(p[0].nsamples == 50)
    self.failUnless(p[1].nsamples == 50)

self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all())
self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all())


# A single rule naming the chunks of both datasets explicitly:
# 3 chunks -> 30 samples, 2 chunks -> 20 samples.
cs = CustomSplitter([([0, 3, 4], [5, 9])])
splits = list(cs(self.data))
self.failUnless(len(splits) == 1)

for p in splits:
    self.failUnless(len(p) == 2)
    self.failUnless(p[0].nsamples == 30)
    self.failUnless(p[1].nsamples == 20)

self.failUnless((splits[0][1].uniquechunks == [5, 9]).all())
self.failUnless((splits[0][0].uniquechunks == [0, 3, 4]).all())


# Three datasets per split, a fixed number of samples per label in
# each of them, and three runs of the single rule.
cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                    nperlabel=[3, 4, 1],
                    nrunspersplit=3)
splits = list(cs(self.data))
self.failUnless(len(splits) == 3)

for p in splits:
    self.failUnless(len(p) == 3)
    # The fixture has four labels, so nsamples == 4 * nperlabel.
    self.failUnless(p[0].nsamples == 12)
    self.failUnless(p[1].nsamples == 16)
    self.failUnless(p[2].nsamples == 4)


# Fractional sample selection: a per-label list of fractions for the
# first dataset, one fraction for the second, 'all' for the third.
cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                    nperlabel=[[0.3, 0.6, 1.0, 0.5],
                               0.5,
                               'all'],
                    nrunspersplit=3)
# Reference splitter without any sample selection.
csall = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                       nrunspersplit=3)

splits = list(cs(self.data))
splitsall = list(csall(self.data))

self.failUnless(len(splits) == 3)

# Selected counts must equal the rounded fraction of the full counts.
self.failUnless(((N.array(splitsall[0][0].samplesperlabel.values())
                  * [0.3, 0.6, 1.0, 0.5]).round().astype(int) ==
                 N.array(splits[0][0].samplesperlabel.values())).all())

self.failUnless(((N.array(splitsall[0][1].samplesperlabel.values()) * 0.5
                  ).round().astype(int) ==
                 N.array(splits[0][1].samplesperlabel.values())).all())

# 'all' must leave the per-label counts untouched.
self.failUnless((N.array(splitsall[0][2].samplesperlabel.values()) ==
                 N.array(splits[0][2].samplesperlabel.values())).all())
219
220
# Default NoneSplitter: a single split whose first dataset is None and
# whose second dataset holds all 100 samples.
nos = NoneSplitter()
splits = [(train, test) for (train, test) in nos(self.data)]
self.failUnless(len(splits) == 1)
# Identity test against None avoids any overloaded `==` operator.
self.failUnless(splits[0][0] is None)
self.failUnless(splits[0][1].nsamples == 100)

# mode='first' puts the complete data into the first dataset instead.
nos = NoneSplitter(mode='first')
splits = [(train, test) for (train, test) in nos(self.data)]
self.failUnless(len(splits) == 1)
self.failUnless(splits[0][1] is None)
self.failUnless(splits[0][0].nsamples == 100)


# Sample selection with an explicit number of samples per label:
# 4 labels * 10 samples = 40 samples in each of the 3 runs.
nos = NoneSplitter(nrunspersplit=3,
                   nperlabel=10)
splits = [(train, test) for (train, test) in nos(self.data)]

self.failUnless(len(splits) == 3)
for split in splits:
    self.failUnless(split[0] is None)
    self.failUnless(split[1].nsamples == 40)
    self.failUnless(split[1].samplesperlabel.values() == [10, 10, 10, 10])


# nperlabel='equal': the fixture already has 25 samples per label, so
# all 100 samples are expected to be kept.
nos = NoneSplitter(nrunspersplit=3,
                   nperlabel='equal')
splits = [(train, test) for (train, test) in nos(self.data)]

self.failUnless(len(splits) == 3)
for split in splits:
    self.failUnless(split[0] is None)
    self.failUnless(split[1].nsamples == 100)
    self.failUnless(split[1].samplesperlabel.values() == [25, 25, 25, 25])
257
258
# Odd/even splitting driven by the 'labels' attribute.
label_splitter = OddEvenSplitter(attr='labels')

pairs = [(first, second) for (first, second) in label_splitter(self.data)]

# The fixture labels are 0..3: even labels on one side, odd labels on
# the other, and the second pair is the first one swapped.
first_pair = pairs[0]
self.failUnless((first_pair[0].uniquelabels == [0, 2]).all())
self.failUnless((first_pair[1].uniquelabels == [1, 3]).all())
second_pair = pairs[1]
self.failUnless((second_pair[0].uniquelabels == [1, 3]).all())
self.failUnless((second_pair[1].uniquelabels == [0, 2]).all())
268
269
271
# Limit the number of generated splits via `count`, for every
# selection strategy NFoldSplitter declares in _STRATEGIES.
nchunks = len(self.data.uniquechunks)
for strategy in NFoldSplitter._STRATEGIES:
    # (requested count, expected number of splits); a request above
    # the number of chunks is expected to be capped at nchunks.
    for count, target in [(nchunks*2, nchunks),
                          (nchunks, nchunks),
                          (nchunks-1, nchunks-1),
                          (3, 3),
                          (0, 0),
                          (1, 1)]:
        nfs = NFoldSplitter(cvtype=1, count=count, strategy=strategy)
        splits = [(train, test) for (train, test) in nfs(self.data)]
        self.failUnless(len(splits) == target)
        # With cvtype=1 the second dataset of a split holds one chunk.
        chosenchunks = [int(s[1].uniquechunks) for s in splits]
        if strategy == 'first':
            # list() keeps the comparison list-to-list even where
            # range() does not return a list.
            self.failUnlessEqual(chosenchunks, list(range(target)))
        elif strategy == 'equidistant':
            if target == 3:
                self.failUnlessEqual(chosenchunks, [0, 3, 7])
        elif strategy == 'random':
            # No chunk may be selected twice.
            self.failUnless(len(set(chosenchunks)) == len(chosenchunks))
            self.failUnless(target == len(chosenchunks))
        else:
            # Call form of raise is valid in both Python 2 and 3.
            raise RuntimeError("Add unittest for strategy %s"
                               % strategy)
297
298
# Index 0 is the reference without discarding. Judging by the
# assertions below, discard_boundary=(a, b) drops `a` boundary samples
# from the first dataset and `b` from the second; a scalar applies to
# both. The odd/even and half splitters are only run, not asserted on.
splitters = [
    NFoldSplitter(),
    NFoldSplitter(discard_boundary=(0,1)),
    NFoldSplitter(discard_boundary=(1,0)),
    NFoldSplitter(discard_boundary=(2,0)),
    NFoldSplitter(discard_boundary=1),
    OddEvenSplitter(discard_boundary=(1,0)),
    OddEvenSplitter(discard_boundary=(0,1)),
    HalfSplitter(discard_boundary=(1,0)),
]

# For each splitter: one (n first dataset, n second dataset) tuple per
# generated split.
counts = []
for splitter in splitters:
    counts.append([(len(pair[0].chunks), len(pair[1].chunks))
                   for pair in list(splitter(self.data))])

reference = counts[0]
nodiscard_tr = [n_pair[0] for n_pair in reference]
nodiscard_te = [n_pair[1] for n_pair in reference]

# Discarding in the second dataset only: first-dataset sizes stay the
# same, inner folds lose two samples, the outermost folds lose one.
te_discarded = counts[1]
self.failUnless(nodiscard_tr == [n_pair[0] for n_pair in te_discarded])
self.failUnless(nodiscard_te[1:-1] == [n_pair[1] + 2
                                       for n_pair in te_discarded[1:-1]])
self.failUnless(nodiscard_te[0] == te_discarded[0][1] + 1)
self.failUnless(nodiscard_te[-1] == te_discarded[-1][1] + 1)

# Discarding `width` samples in the first dataset only (splitters at
# index 2 and 3): second-dataset sizes stay the same.
for width in (1, 2):
    tr_discarded = counts[1 + width]
    self.failUnless(nodiscard_te == [n_pair[1] for n_pair in tr_discarded])
    self.failUnless(nodiscard_tr[0] == tr_discarded[0][0] + width)
    self.failUnless(nodiscard_tr[-1] == tr_discarded[-1][0] + width)
    self.failUnless(nodiscard_tr[1:-1] == [n_pair[0] + width*2
                                           for n_pair in tr_discarded[1:-1]])

# Scalar discard_boundary=1 must match the element-wise minimum of
# the (0,1) and (1,0) results.
counts_min = [(min(lhs[0], rhs[0]), min(lhs[1], rhs[1]))
              for lhs, rhs in zip(counts[1], counts[2])]
self.failUnless(counts_min == counts[4])
336
337
338
339
340
341
342
345
346
# When this file is executed as a script, import the project-local
# `runner` module.
# NOTE(review): presumably that import is what kicks off the test run --
# confirm against runner.py.
if __name__ == '__main__':
    import runner
349