9 """Unit tests for PyMVPA recursive feature elimination"""
10
11 from sets import Set
12
13 from mvpa.datasets.splitters import NFoldSplitter
14 from mvpa.algorithms.cvtranserror import CrossValidatedTransferError
15 from mvpa.datasets.masked import MaskedDataset
16 from mvpa.measures.base import FeaturewiseDatasetMeasure
17 from mvpa.featsel.rfe import RFE
18 from mvpa.featsel.base import \
19 SensitivityBasedFeatureSelection, \
20 FeatureSelectionPipeline
21 from mvpa.featsel.helpers import \
22 NBackHistoryStopCrit, FractionTailSelector, FixedErrorThresholdStopCrit, \
23 MultiStopCrit, NStepsStopCrit, \
24 FixedNElementTailSelector, BestDetector, RangeElementSelector
25
26 from mvpa.clfs.meta import FeatureSelectionClassifier, SplitClassifier
27 from mvpa.clfs.transerror import TransferError
28 from mvpa.misc.transformers import Absolute
29
30 from mvpa.misc.state import UnknownStateError
31
32 from tests_warehouse import *
33 from tests_warehouse_clfs import *
34
36 """Simple one which just returns xrange[-N/2, N/2], where N is the
37 number of features
38 """
39
43
45 """Train linear SVM on `dataset` and extract weights from classifier.
46 """
47 return( self.__mult *( N.arange(dataset.nfeatures) - int(dataset.nfeatures/2) ))
48

class RFETests(unittest.TestCase):
    """Tests for RFE and its helper classes."""

    def getData(self):
        # NB: the particular warehouse dataset is an assumption; any small
        #     labeled dataset from tests_warehouse would do here
        return datasets['uni2medium_train']

    def getDataT(self):
        return datasets['uni2medium_test']
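
    # BestDetector reports whether the value most recently appended to an
    # error history is the best one seen so far (the minimum by default,
    # or whatever `func` selects).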
    def testBestDetector(self):
        bd = BestDetector()

        # for an empty history there is no best
        self.failUnless(bd([]) == False)
        # a single entry is trivially the best so far
        self.failUnless(bd([1]) == True)
        # the last value is the minimum of the history
        self.failUnless(bd([1, 0.9, 0.8]) == True)

        # alternative comparison function: the maximum is the best
        bd = BestDetector(func=max)
        self.failUnless(bd([0.8, 0.9, 1.0]) == True)
        self.failUnless(bd([0.8, 0.9, 1.0]+[0.9]*9) == False)
        self.failUnless(bd([0.8, 0.9, 1.0]+[0.9]*10) == False)

        # detect the latest vs. the earliest of several equal minima
        bd = BestDetector(lastminimum=True)
        self.failUnless(bd([3, 2, 1, 1, 1, 2, 1]) == True)
        bd = BestDetector()
        self.failUnless(bd([3, 2, 1, 1, 1, 2, 1]) == False)
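
    # NBackHistoryStopCrit signals "stop" once the best value has not been
    # renewed within the last `steps` entries of the error history.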
    def testNBackHistoryStopCrit(self):
        """Test the n-back history stopping criterion"""
        stopcrit = NBackHistoryStopCrit()
        # for an empty history do not stop
        self.failUnless(stopcrit([]) == False)
        # the best value was seen recently enough
        self.failUnless(stopcrit(
            [1, 0.9, 0.8]+[0.9]*(stopcrit.steps-1)) == False)
        # the best value is now too far back in the history
        self.failUnless(stopcrit(
            [1, 0.9, 0.8]+[0.9]*stopcrit.steps) == True)

        # criterion driven by a maximum-based best detector
        stopcrit = NBackHistoryStopCrit(BestDetector(func=max))
        self.failUnless(stopcrit([0.8, 0.9, 1.0]+[0.9]*9) == False)
        self.failUnless(stopcrit([0.8, 0.9, 1.0]+[0.9]*10) == True)

        # with lastminimum=True the trailing 1 still counts as the best,
        # so there is no reason to stop yet
        stopcrit = NBackHistoryStopCrit(BestDetector(lastminimum=True))
        self.failUnless(stopcrit([3, 2, 1, 1, 1, 2, 1]) == False)
        # with the default detector and a short window the best is too old
        stopcrit = NBackHistoryStopCrit(steps=4)
        self.failUnless(stopcrit([3, 2, 1, 1, 1, 2, 1]) == True)
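
    # FixedErrorThresholdStopCrit stops as soon as the most recent error
    # drops below the configured threshold; earlier sub-threshold values
    # in the history do not count.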
107 """Test stopping criterion"""
108 stopcrit = FixedErrorThresholdStopCrit(0.5)
109
110 self.failUnless(stopcrit([]) == False)
111 self.failUnless(stopcrit([0.8, 0.9, 0.5]) == False)
112 self.failUnless(stopcrit([0.8, 0.9, 0.4]) == True)
113
114 self.failUnless(stopcrit([0.8, 0.4, 0.6]) == False)
115
116
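
    # NStepsStopCrit stops after a fixed number of entries has been
    # recorded, regardless of the error values themselves.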
118 """Test stopping criterion"""
119 stopcrit = NStepsStopCrit(2)
120
121 self.failUnless(stopcrit([]) == False)
122 self.failUnless(stopcrit([0.8, 0.9]) == True)
123 self.failUnless(stopcrit([0.8]) == False)
124
125
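
    # MultiStopCrit combines several criteria: by default it fires when any
    # of them does ('or' mode), and can be told to require all ('and' mode).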
127 """Test multiple stop criteria"""
128 stopcrit = MultiStopCrit([FixedErrorThresholdStopCrit(0.5),
129 NBackHistoryStopCrit(steps=4)])
130
131
132
133 self.failUnless(stopcrit([1, 0.9, 0.8]+[0.9]*4) == True)
134
135 self.failUnless(stopcrit([1, 0.9, 0.2]) == True)
136
137
138 stopcrit = MultiStopCrit([FixedErrorThresholdStopCrit(0.5),
139 NBackHistoryStopCrit(steps=4)],
140 mode = 'and')
141
142 self.failUnless(stopcrit([1, 0.9, 0.8]+[0.9]*4) == False)
143
144 self.failUnless(stopcrit([1, 0.9, 0.2]) == False)
145
146 self.failUnless(stopcrit([1, 0.9, 0.4]+[0.4]*4) == True)
147
148
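
    # The element selectors return indices of the elements that survive
    # pruning: FractionTailSelector discards a fraction of the lowest values,
    # FixedNElementTailSelector discards a fixed number of them, and
    # RangeElementSelector selects (or discards) elements by value range.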
150 """Test feature selector"""
151
152 selector = FractionTailSelector(0.1)
153 data = N.array([3.5, 10, 7, 5, -0.4, 0, 0, 2, 10, 9])
154
155 target10 = N.array([0, 1, 2, 3, 5, 6, 7, 8, 9])
156 target30 = N.array([0, 1, 2, 3, 7, 8, 9])
157
158 self.failUnlessRaises(UnknownStateError,
159 selector.__getattribute__, 'ndiscarded')
160 self.failUnless((selector(data) == target10).all())
161 selector.felements = 0.30
162 self.failUnless(selector.felements == 0.3)
163 self.failUnless((selector(data) == target30).all())
164 self.failUnless(selector.ndiscarded == 3)
165
166 selector = FixedNElementTailSelector(1)
167
168 data = N.array([3.5, 10, 7, 5, -0.4, 0, 0, 2, 10, 9])
169 self.failUnless((selector(data) == target10).all())
170
171 selector.nelements = 3
172 self.failUnless(selector.nelements == 3)
173 self.failUnless((selector(data) == target30).all())
174 self.failUnless(selector.ndiscarded == 3)
175
176
177
178 self.failUnless((RangeElementSelector(lower=0)(data) == \
179 N.array([0,1,2,3,7,8,9])).all())
180
181 self.failUnless((RangeElementSelector(lower=0,
182 inclusive=True)(data) == \
183 N.array([0,1,2,3,5,6,7,8,9])).all())
184
185 self.failUnless((RangeElementSelector(lower=0, mode='discard',
186 inclusive=True)(data) == \
187 N.array([4])).all())
188
189
190 self.failUnless((RangeElementSelector(upper=2)(data) == \
191 N.array([4,5,6])).all())
192
193 self.failUnless((RangeElementSelector(upper=2,
194 inclusive=True)(data) == \
195 N.array([4,5,6,7])).all())
196
197 self.failUnless((RangeElementSelector(upper=2, mode='discard',
198 inclusive=True)(data) == \
199 N.array([0,1,2,3,8,9])).all())
200
201
202
203 self.failUnless((RangeElementSelector(lower=2, upper=9)(data) == \
204 N.array([0,2,3])).all())
205
206 self.failUnless((RangeElementSelector(lower=2, upper=9,
207 inclusive=True)(data) == \
208 N.array([0,2,3,7,9])).all())
209
210 self.failUnless((RangeElementSelector(upper=2, lower=9, mode='discard',
211 inclusive=True)(data) ==
212 RangeElementSelector(lower=2, upper=9,
213 inclusive=False)(data)).all())
214
215
216 self.failUnless((RangeElementSelector()(data) == \
217 N.nonzero(data)[0]).all())
218
219
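
    # SensitivityBasedFeatureSelection performs a single selection step:
    # compute a sensitivity map, hand it to a feature selector, and return
    # pruned copies of the training and testing datasets.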
    @sweepargs(clf=clfswh['has_sensitivity', '!meta'])
    def testSensitivityBasedFeatureSelection(self, clf):

        sens_ana = clf.getSensitivityAnalyzer()

        # number of features to remove
        Nremove = 2

        # remove exactly two features based on the classifier's sensitivity
        fe = SensitivityBasedFeatureSelection(sens_ana,
                feature_selector=FixedNElementTailSelector(2),
                enable_states=["sensitivity", "selected_ids"])

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        sdata, stdata = fe(wdata, tdata)

        # the original datasets must not be touched
        self.failUnless(wdata.nfeatures == wdata_nfeatures)
        self.failUnless(tdata.nfeatures == tdata_nfeatures)

        self.failUnlessEqual(wdata.nfeatures, sdata.nfeatures+Nremove,
            msg="Exactly %d features had to be removed" % Nremove)

        self.failUnlessEqual(tdata.nfeatures, stdata.nfeatures+Nremove,
            msg="Exactly %d features had to be removed from the testing "
                "dataset as well" % Nremove)

        self.failUnlessEqual(len(fe.sensitivity), wdata_nfeatures,
            msg="Sensitivity must have as many elements as the original "
                "dataset has features")

        self.failUnlessEqual(len(fe.selected_ids), sdata.nfeatures,
            msg="Number of selected features must equal the number of "
                "features in the result dataset")
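
    # FeatureSelectionPipeline chains several FeatureSelection instances,
    # feeding the output of one step into the next and recording how many
    # features entered each step.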
    def testFeatureSelectionPipeline(self):
        sens_ana = SillySensitivityAnalyzer()

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        # the silly analyzer returns a ramp starting at -nfeatures/2
        self.failUnlessEqual(sens_ana(wdata)[0], -int(wdata_nfeatures/2))

        # two-step pipeline: drop the lowest 25%, then 4 more features
        feature_selections = [SensitivityBasedFeatureSelection(
                                sens_ana,
                                FractionTailSelector(0.25)),
                              SensitivityBasedFeatureSelection(
                                sens_ana,
                                FixedNElementTailSelector(4))
                              ]

        feat_sel_pipeline = FeatureSelectionPipeline(
            feature_selections=feature_selections,
            enable_states=['nfeatures', 'selected_ids'])

        sdata, stdata = feat_sel_pipeline(wdata, tdata)

        self.failUnlessEqual(len(feat_sel_pipeline.feature_selections),
                             len(feature_selections),
                             msg="Test the property feature_selections")

        desired_nfeatures = int(N.ceil(wdata_nfeatures*0.75))
        self.failUnlessEqual(feat_sel_pipeline.nfeatures,
                             [wdata_nfeatures, desired_nfeatures],
                             msg="Test if nfeatures get assigned properly."
                                 " Got %s!=%s" % (feat_sel_pipeline.nfeatures,
                                                  [wdata_nfeatures, desired_nfeatures]))

        # since the silly sensitivities form an increasing ramp, only the
        # trailing features survive both steps
        self.failUnlessEqual(list(feat_sel_pipeline.selected_ids),
                             range(int(wdata_nfeatures*0.25)+4, wdata_nfeatures))
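
    # RFE itself repeatedly recomputes sensitivities, discards the weakest
    # features and tracks the transfer error until a stopping criterion
    # fires.  The test below is only a minimal sketch of such a run: it
    # assumes the RFE(sensitivity_analyzer, transfer_error, ...) call
    # signature and picks an arbitrary one-feature-at-a-time selector, so it
    # checks generic invariants rather than any particular result.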
    @sweepargs(clf=clfswh['has_sensitivity', '!meta'][:1])
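    def testRFE(self, clf):
        """Minimal smoke-test sketch for RFE.

        Only documented behaviour is assumed: RFE is driven by a sensitivity
        analyzer plus a TransferError and, like any FeatureSelection, returns
        pruned training and testing datasets, leaving the originals intact.
        """
        sens_ana = clf.getSensitivityAnalyzer()
        trans_error = TransferError(clf)
        rfe = RFE(sens_ana,
                  trans_error,
                  feature_selector=FixedNElementTailSelector(1))

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        sdata, stdata = rfe(wdata, tdata)

        # the input datasets must not be modified
        self.failUnless(wdata.nfeatures == wdata_nfeatures)
        self.failUnless(tdata.nfeatures == tdata_nfeatures)
        # RFE cannot end up with more features than it started with, and the
        # testing dataset has to be pruned to the very same feature set
        self.failUnless(sdata.nfeatures <= wdata_nfeatures)
        self.failUnlessEqual(sdata.nfeatures, stdata.nfeatures)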


def suite():
    return unittest.makeSuite(RFETests)


if __name__ == '__main__':
    import runner