SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CrossValidation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/machine/Machine.h>
15 #include <shogun/base/Parameter.h>
17 
18 using namespace shogun;
19 
/* NOTE(review): the declarator line for this body is missing from this
 * listing - presumably the CCrossValidation default constructor; confirm
 * against the original file.
 *
 * The body only calls init(), which sets all object pointers to NULL,
 * sets m_num_runs to 1 and m_conf_int_alpha to 0, and registers the members
 * with m_parameters. */
21 {
22  init();
23 }
24 
/* NOTE(review): the first line of this declarator (function name and leading
 * parameters) is missing from this listing. The body reads `machine` and
 * `features`, so both have to be declared there - presumably this is the
 * CCrossValidation constructor taking all five collaborators; confirm.
 *
 * Resets the object through init(), stores the given pointers in the
 * corresponding members and then calls SG_REF on every member. */
26  CLabels* labels, CSplittingStrategy* splitting_strategy,
27  CEvaluation* evaluation_criterium)
28 {
    /* start from the default state (NULL members, registered parameters) */
29  init();
30 
    /* remember the passed objects */
31  m_machine=machine;
32  m_features=features;
33  m_labels=labels;
34  m_splitting_strategy=splitting_strategy;
35  m_evaluation_criterium=evaluation_criterium;
36 
    /* SG_REF each stored object - presumably this takes a reference so the
     * objects stay alive as long as this instance holds them; confirm */
37  SG_REF(m_machine);
38  SG_REF(m_features);
39  SG_REF(m_labels);
40  SG_REF(m_splitting_strategy);
41  SG_REF(m_evaluation_criterium);
42 }
43 
/* NOTE(review): the declarator line for this body is missing from this
 * listing - presumably the CCrossValidation destructor; confirm.
 *
 * Calls SG_UNREF on each of the five object members, i.e. on exactly the
 * members that the parameterised constructor passes to SG_REF. */
45 {
46  SG_UNREF(m_machine);
47  SG_UNREF(m_features);
48  SG_UNREF(m_labels);
49  SG_UNREF(m_splitting_strategy);
50  SG_UNREF(m_evaluation_criterium);
51 }
52 
/* NOTE(review): the declarator line (name and return type) for this body is
 * missing from this listing - presumably
 * CCrossValidation::get_evaluation_direction(); confirm.
 *
 * Forwards to the evaluation criterion and returns whatever its
 * get_evaluation_direction() returns. */
54 {
    /* NOTE(review): m_evaluation_criterium is dereferenced without a check;
     * init() sets it to NULL, so an instance that never had a criterion
     * assigned would dereference NULL here */
55  return m_evaluation_criterium->get_evaluation_direction();
56 }
57 
58 void CCrossValidation::init()
59 {
60  m_machine=NULL;
61  m_features=NULL;
62  m_labels=NULL;
63  m_splitting_strategy=NULL;
64  m_evaluation_criterium=NULL;
65  m_num_runs=1;
66  m_conf_int_alpha=0;
67 
68  m_parameters->add((CSGObject**) &m_machine, "machine",
69  "Used learning machine");
70  m_parameters->add((CSGObject**) &m_features, "features", "Used features");
71  m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels");
72  m_parameters->add((CSGObject**) &m_splitting_strategy,
73  "splitting_strategy", "Used splitting strategy");
74  m_parameters->add((CSGObject**) &m_evaluation_criterium,
75  "evaluation_criterium", "Used evaluation criterium");
76  m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions");
77  m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence "
78  "interval");
79 }
80 
/* NOTE(review): the declarator line (name and return type) for this body is
 * missing from this listing - presumably CCrossValidation::get_machine();
 * confirm.
 *
 * Returns m_machine after calling SG_REF on it. */
82 {
    /* SG_REF before handing the pointer out - presumably the caller is
     * expected to SG_UNREF the returned object when done; confirm */
83  SG_REF(m_machine);
84  return m_machine;
85 }
86 
/* NOTE(review): the declarator line for this body is missing from this
 * listing; the body returns a CrossValidationResult by value - presumably
 * CCrossValidation::evaluate(); confirm.
 *
 * Calls evaluate_one_run() m_num_runs times and summarises the per-run
 * values in a CrossValidationResult: the mean, plus confidence bounds when
 * m_conf_int_alpha is non-zero (both bounds are set to 0 otherwise). */
88 {
    /* one slot per repetition */
89  SGVector<float64_t> results(m_num_runs);
90 
91  for (index_t i=0; i<m_num_runs; ++i)
92  results.vector[i]=evaluate_one_run();
93 
94  /* construct evaluation result */
95  CrossValidationResult result;
    /* an alpha of exactly 0 means no confidence interval is computed */
96  result.has_conf_int=m_conf_int_alpha!=0;
97  result.conf_int_alpha=m_conf_int_alpha;
98 
99  if (result.has_conf_int)
100  {
    /* NOTE(review): repeats the assignment made right before the if */
101  result.conf_int_alpha=m_conf_int_alpha;
    /* NOTE(review): the line that starts this call is missing from the
     * listing, only its trailing arguments are visible. Presumably it
     * assigns result.mean and fills conf_int_low and conf_int_up for the
     * given alpha - confirm against the original file */
103  result.conf_int_alpha, result.conf_int_low, result.conf_int_up);
104  }
105  else
106  {
    /* plain mean only, bounds are zeroed */
107  result.mean=CStatistics::mean(results);
108  result.conf_int_low=0;
109  result.conf_int_up=0;
110  }
111 
    /* the buffer behind results is released by hand.
     * NOTE(review): this line is not reached if evaluate_one_run() or the
     * statistics calls leave the function early (e.g. by throwing) */
112  SG_FREE(results.vector);
113 
114  return result;
115 }
116 
/* NOTE(review): the declarator line for this body is missing from this
 * listing; the body reads a parameter named conf_int_alpha (printed with %f)
 * - presumably CCrossValidation::set_conf_int_alpha(); confirm.
 *
 * Stores the alpha value used for the confidence interval. Values below 0
 * or greater than or equal to 1 are reported through SG_ERROR. */
118 {
    /* legal range is 0 <= alpha < 1 */
119  if (conf_int_alpha<0 || conf_int_alpha>=1)
120  {
121  SG_ERROR("%f is an illegal alpha-value for confidence interval of "
122  "cross-validation\n", conf_int_alpha);
123  }
124 
    /* NOTE(review): whether this line is still reached after SG_ERROR depends
     * on what SG_ERROR does, which is not visible in this file */
125  m_conf_int_alpha=conf_int_alpha;
126 }
127 
128 void CCrossValidation::set_num_runs(int32_t num_runs)
129 {
130  if (num_runs<1)
131  SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
132 
133  m_num_runs=num_runs;
134 }
135 
/* NOTE(review): the declarator line for this body is missing from this
 * listing; the body returns a float64_t - presumably
 * CCrossValidation::evaluate_one_run(); confirm.
 *
 * Performs one pass over all subsets of the splitting strategy. For subset i
 * the machine is trained with features and labels restricted to the inverse
 * index set of i, applied to the features restricted to the index set of i,
 * and the evaluation criterion compares that output with the labels
 * restricted to the same index set. Returns the arithmetic mean of the
 * per-subset evaluation values. */
137 {
138  index_t num_subsets=m_splitting_strategy->get_num_subsets();
    /* one evaluation value per subset; raw buffer, freed at the end.
     * NOTE(review): the buffer is not released, and the subsets on
     * m_features and m_labels are not removed, if any call in the loop
     * leaves the function early (e.g. by throwing) */
139  float64_t* results=SG_MALLOC(float64_t, num_subsets);
140 
141  /* set labels to machine */
142  m_machine->set_labels(m_labels);
143 
144  /* tell machine to store model internally
145  * (otherwise changing subset of features will kaboom the classifier) */
146  m_machine->set_store_model_features(true);
147 
148  /* do actual cross-validation */
149  for (index_t i=0; i<num_subsets; ++i)
150  {
151  /* set feature subset for training */
152  SGVector<index_t> inverse_subset_indices=
153  m_splitting_strategy->generate_subset_inverse(i);
154  m_features->set_subset(new CSubset(inverse_subset_indices));
155 
156  /* set label subset for training (copy data before) */
    /* the indices are duplicated so that the feature subset and the label
     * subset each get their own buffer - presumably because CSubset keeps
     * (and later frees) the vector it is given; confirm */
157  SGVector<index_t> inverse_subset_indices_copy(
158  inverse_subset_indices.vlen);
159  memcpy(inverse_subset_indices_copy.vector,
160  inverse_subset_indices.vector,
161  inverse_subset_indices.vlen*sizeof(index_t));
162  m_labels->set_subset(new CSubset(inverse_subset_indices_copy));
163 
164  /* train machine on training features */
165  m_machine->train(m_features);
166 
167  /* set feature subset for testing (subset method that stores pointer) */
168  SGVector<index_t> subset_indices=
169  m_splitting_strategy->generate_subset_indices(i);
170  m_features->set_subset(new CSubset(subset_indices));
171 
172  /* apply machine to test features */
173  CLabels* result_labels=m_machine->apply(m_features);
    /* hold the output until it has been evaluated; released again in the
     * clean-up at the end of this iteration */
174  SG_REF(result_labels);
175 
176  /* set label subset for testing (copy data before) */
177  SGVector<index_t> subset_indices_copy(subset_indices.vlen);
178  memcpy(subset_indices_copy.vector, subset_indices.vector,
179  subset_indices.vlen*sizeof(index_t));
180  m_labels->set_subset(new CSubset(subset_indices_copy));
181 
182  /* evaluate */
    /* machine output versus the labels restricted to the test indices */
183  results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels);
184 
185  /* clean up, reset subsets */
186  SG_UNREF(result_labels);
187  m_features->remove_subset();
188  m_labels->remove_subset();
189  }
190 
191  /* build arithmetic mean of results */
    /* the temporary SGVector is built from the raw buffer and its length;
     * the buffer itself is still freed by hand below */
192  float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets));
193 
194  /* clean up */
195  SG_FREE(results);
196 
197  return mean;
198 }

SHOGUN Machine Learning Toolbox - Documentation