SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
SegmentLoss.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2009 Jonas Behr
8  * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 #ifndef __SEGMENT_LOSS__
11 #define __SEGMENT_LOSS__
12 
13 #include <shogun/lib/common.h>
14 #include <shogun/base/SGObject.h>
15 #include <shogun/lib/Array.h>
16 #include <shogun/lib/Array2.h>
17 #include <shogun/lib/Array3.h>
18 
19 
20 namespace shogun
21 {
22  template <class T> class CArray;
23  template <class T> class CArray2;
24  template <class T> class CArray3;
/**
 * CSegmentLoss: a CSGObject that holds segment-loss data and evaluates the
 * loss of a segment of a given type between two positions (see the inline
 * definitions of get_segment_loss() and get_segment_loss_extend() that
 * follow this class in the header).
 */
26 class CSegmentLoss : public CSGObject
27 {
28  public:
29 
	/** default constructor */
32  CSegmentLoss();
33 
	/** destructor (defined out of line; what it releases is not visible here) */
34  virtual ~CSegmentLoss();
35 
	/** loss of a segment of type segment_id between from_pos and to_pos
	 *
	 * Defined inline in this header: the difference of columns from_pos and
	 * to_pos in row segment_id of m_segment_loss_matrix, plus the entry
	 * m_segment_loss(segment_id, m_segment_ids[to_pos-1], 0) weighted by
	 * m_segment_mask[to_pos-1].
	 *
	 * @param from_pos start position
	 * @param to_pos end position (index to_pos-1 is also read)
	 * @param segment_id segment type (row of m_segment_loss_matrix)
	 * @return segment loss
	 */
42  float32_t get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id);
43 
	/** variant of get_segment_loss() with a shifted start position
	 *
	 * Defined inline in this header: the start is first advanced from
	 * from_pos while it is below to_pos and m_segment_ids holds the same id
	 * at that index and the next one; the result is the difference of the
	 * shifted-start and to_pos columns in row segment_id of
	 * m_segment_loss_matrix. No mask-weighted m_segment_loss term is added.
	 *
	 * @param from_pos start position before the shift
	 * @param to_pos end position
	 * @param segment_id segment type (row of m_segment_loss_matrix)
	 * @return segment loss
	 */
50  float32_t get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id);
51 
	/** set segment loss values
	 *
	 * NOTE(review): defined out of line. Presumably copies the m x n buffer
	 * into m_segment_loss (which get_segment_loss() reads with three
	 * indices) - confirm the expected layout against the implementation.
	 *
	 * @param segment_loss loss values
	 * @param m first dimension (presumably)
	 * @param n second dimension (presumably)
	 */
58  void set_segment_loss(float64_t* segment_loss, int32_t m, int32_t n);
59 
	/** set the segment id array
	 *
	 * NOTE(review): defined out of line; presumably stores the pointer as
	 * m_segment_ids, which the inline getters dereference. Ownership /
	 * reference counting is not visible here - confirm.
	 *
	 * @param segment_ids segment id per position (presumably)
	 */
64  void set_segment_ids(CArray<int32_t>* segment_ids);
65 
	/** set the segment mask array
	 *
	 * NOTE(review): defined out of line; presumably stores the pointer as
	 * m_segment_mask, which get_segment_loss() dereferences. Ownership /
	 * reference counting is not visible here - confirm.
	 *
	 * @param segment_mask mask weight per position (presumably)
	 */
72  void set_segment_mask(CArray<float64_t>* segment_mask);
73 
	/** set the number of segment types (stored in m_num_segment_types)
	 *
	 * @param num_segment_types number of segment types
	 */
78  void set_num_segment_types(int32_t num_segment_types)
79  {
80  m_num_segment_types = num_segment_types;
81  }
82 
	/** compute loss
	 *
	 * NOTE(review): defined out of line. Presumably fills
	 * m_segment_loss_matrix for the len positions in all_pos - confirm.
	 *
	 * @param all_pos positions
	 * @param len number of positions in all_pos
	 */
88  void compute_loss(int32_t* all_pos, int32_t len);
89 
	/** @return object name ("SegmentLoss") */
93  inline virtual const char* get_name() const { return "SegmentLoss"; }
94  protected:
	/* NOTE(review): the protected data-member declarations were lost in this
	 * documentation listing (numbered lines 98, 104, 107, 110 below are blank).
	 * From their use in this header the class needs at least:
	 *   m_segment_loss_matrix - value object with element(row, col)
	 *   m_segment_loss        - value object with element(i, j, k)
	 *   m_segment_ids         - pointer; set_segment_ids() takes CArray<int32_t>*
	 *   m_segment_mask        - pointer; set_segment_mask() takes CArray<float64_t>*
	 *   m_num_segment_types   - assigned an int32_t in set_num_segment_types()
	 * TODO: restore the exact declarations from the original header. */
95 
98 
104 
107 
110 
113 };
114 
115 inline float32_t CSegmentLoss::get_segment_loss(int32_t from_pos, int32_t to_pos, int32_t segment_id)
116 {
117 
118  /* int32_t from_pos_shift = from_pos ;
119  if (print)
120  SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
121  from_pos_shift, to_pos, segment_id,
122  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
123  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
124  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
125  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;
126  while(1)
127  {
128  while (m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1) && from_pos_shift<to_pos)
129  from_pos_shift++ ;
130  if (print)
131  SG_PRINT("# pos=%i,%i segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
132  from_pos_shift, to_pos, segment_id,
133  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
134  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
135  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
136  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;
137 
138  if (from_pos_shift>=to_pos)
139  {
140  //SG_PRINT("break") ;
141  break ;
142  }
143  else from_pos_shift++ ;
144  }
145  if (print)
146  SG_PRINT("break\n") ; */
147 
148  float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos)-m_segment_loss_matrix.element(segment_id, to_pos);
149  diff_contrib += m_segment_mask->element(to_pos-1)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos-1), 0);
150  return diff_contrib;
151 }
152 
153 inline float32_t CSegmentLoss::get_segment_loss_extend(int32_t from_pos, int32_t to_pos, int32_t segment_id)
154 {
155  int32_t from_pos_shift = from_pos ;
156 
157  /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
158  segment_id,
159  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
160  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
161  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
162  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/
163 
164  while (from_pos_shift<to_pos && m_segment_ids->element(from_pos_shift)==m_segment_ids->element(from_pos_shift+1))
165  from_pos_shift++ ;
166 
167  /*SG_PRINT("segment_id=%i, m_segment_ids[from-2]=%i (%1.1f), m_segment_ids[from-1]=%i (%1.1f), m_segment_ids[from]=%i (%1.1f), m_segment_ids[from+1]=%i (%1.1f), \n",
168  segment_id,
169  m_segment_ids->element(from_pos_shift-2), m_segment_loss_matrix.element(segment_id, from_pos_shift-2)-m_segment_loss_matrix.element(segment_id, to_pos),
170  m_segment_ids->element(from_pos_shift-1), m_segment_loss_matrix.element(segment_id, from_pos_shift-1)-m_segment_loss_matrix.element(segment_id, to_pos),
171  m_segment_ids->element(from_pos_shift), m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos),
172  m_segment_ids->element(from_pos_shift+1), m_segment_loss_matrix.element(segment_id, from_pos_shift+1)-m_segment_loss_matrix.element(segment_id, to_pos)) ;*/
173 
174  float32_t diff_contrib = m_segment_loss_matrix.element(segment_id, from_pos_shift)-m_segment_loss_matrix.element(segment_id, to_pos);
175  //diff_contrib += m_segment_mask->element(to_pos)*m_segment_loss.element(segment_id, m_segment_ids->element(to_pos), 0);
176 
177  //if (from_pos_shift!=from_pos)
178  // SG_PRINT("shifting from %i to %i, to_pos=%i, loss=%1.1f\n", from_pos, from_pos_shift, to_pos, diff_contrib) ;
179 
180  return diff_contrib;
181 }
182 }
183 #endif

SHOGUN Machine Learning Toolbox - Documentation