1
2
3
4
5
6
7
8
9 """Provide sensitivity measures for libsvm's SVM."""
10
11 __docformat__ = 'restructuredtext'
12
13 import numpy as N
14
15 from mvpa.base import warning
16 from mvpa.misc.state import StateVariable
17 from mvpa.misc.param import Parameter
18 from mvpa.measures.base import Sensitivity
19
20 if __debug__:
21 from mvpa.base import debug
22
24 """`SensitivityAnalyzer` for the LIBSVM implementation of a linear SVM.
25 """
26
27 _ATTRIBUTE_COLLECTIONS = ['params']
28
29 biases = StateVariable(enabled=True,
30 doc="Offsets of separating hyperplanes")
31
32 split_weights = Parameter(False, allowedtype='bool',
33 doc="If binary classification either to sum SVs per each "
34 "class separately")
35
37 """Initialize the analyzer with the classifier it shall use.
38
39 :Parameters:
40 clf: LinearSVM
41 classifier to use. Only classifiers sub-classed from
42 `LinearSVM` may be used.
43 """
44
45 Sensitivity.__init__(self, clf, **kwargs)
46
47
def _call(self, dataset, callables=None):
    """Compute linear SVM weights from the trained libsvm model.

    The weights are the (matrix) product of the SV coefficients and the
    support vectors of ``self.clf.model``.  As a side effect the model's
    rho values are stored in the `biases` state variable.

    :Parameters:
      dataset: Dataset
        Only consulted when `split_weights` is enabled: the per-class
        weight vectors are ordered according to ``dataset.uniquelabels``.
      callables: list or None
        Accepted for interface compatibility; not used here.  Defaults to
        ``None`` rather than a mutable ``[]`` shared between calls.

    :Returns:
      `numpy.ndarray` -- ``weights.T`` as a plain array.  In the default
      mode ``weights`` is ``svcoef * svs``; with `split_weights` it has
      one row per entry of ``dataset.uniquelabels``.

    :Raises:
      NotImplementedError -- if `split_weights` is enabled and the model
      was not trained on exactly two classes.
    """
    model = self.clf.model
    nr_class = model.nr_class

    if nr_class != 2:
        warning("You are estimating sensitivity for SVM %s trained on %d" %
                (str(self.clf), nr_class) +
                " classes. Make sure that it is what you intended to do")

    # N.matrix makes ``*`` below a true matrix product.
    svcoef = N.matrix(model.getSVCoef())
    svs = N.matrix(model.getSV())
    rhos = N.asarray(model.getRho())

    self.biases = rhos

    if self.split_weights:
        if nr_class != 2:
            # Python 2/3 compatible raise form (the old ``raise X, msg``
            # syntax is a SyntaxError on Python 3).
            raise NotImplementedError(
                "Cannot compute per-class weights for"
                " non-binary classification task")

        svm_labels = model.getLabels()
        ds_labels = list(dataset.uniquelabels)
        senses = [None] * len(ds_labels)

        # SVs with positive coefficients are attributed to the first SVM
        # label; SVs with negative coefficients (sign-flipped so the result
        # points "towards" that class) to the second one.
        for i, (mask, sign) in enumerate([(svcoef > 0, 1),
                                          (svcoef < 0, -1)]):
            # Only the first row of the mask is used -- presumably the single
            # coefficient row of a binary model (TODO confirm for libsvm).
            selected = mask.A[0]
            per_class = sign * (svcoef[:, selected] * svs[selected, :])
            senses[ds_labels.index(svm_labels[i])] = per_class.A[0]
        weights = N.array(senses)
    else:
        weights = svcoef * svs

    if __debug__:
        debug('SVM',
              "Extracting weights for %d-class SVM: #SVs=%s, " %
              (nr_class, str(model.getNSV())) +
              " SVcoefshape=%s SVs.shape=%s Rhos=%s." %
              (svcoef.shape, svs.shape, rhos) +
              " Result: min=%f max=%f" % (N.min(weights), N.max(weights)))

    return N.asarray(weights.T)
102
103 _customizeDocInherit = True
104