Point Cloud Library (PCL)  1.8.1
fern_trainer.hpp
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #ifndef PCL_ML_FERNS_FERN_TRAINER_HPP_
39 #define PCL_ML_FERNS_FERN_TRAINER_HPP_
40 
41 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////
42 template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
44  : fern_depth_ (10)
45  , num_of_features_ (1000)
46  , num_of_thresholds_ (10)
47  , feature_handler_ (NULL)
48  , stats_estimator_ (NULL)
49  , data_set_ ()
50  , label_data_ ()
51  , examples_ ()
52 {
53 
54 }
55 
56 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////
57 template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
59 {
60 
61 }
62 
63 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////
64 template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
65 void
68 {
69  const size_t num_of_branches = stats_estimator_->getNumOfBranches ();
70  const size_t num_of_examples = examples_.size ();
71 
72  // create random features
73  std::vector<FeatureType> features;
74  feature_handler_->createRandomFeatures (num_of_features_, features);
75 
76  // setup fern
77  fern.initialize (fern_depth_);
78 
79  // evaluate all features
80  std::vector<std::vector<float> > feature_results (num_of_features_);
81  std::vector<std::vector<unsigned char> > flags (num_of_features_);
82 
83  for (size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
84  {
85  feature_results[feature_index].reserve (num_of_examples);
86  flags[feature_index].reserve (num_of_examples);
87 
88  feature_handler_->evaluateFeature (features[feature_index],
89  data_set_,
90  examples_,
91  feature_results[feature_index],
92  flags[feature_index] );
93  }
94 
95  // iteratively select features and thresholds
96  std::vector<std::vector<std::vector<float> > > branch_feature_results (num_of_features_); // [feature_index][branch_index][result_index]
97  std::vector<std::vector<std::vector<unsigned char> > > branch_flags (num_of_features_); // [feature_index][branch_index][flag_index]
98  std::vector<std::vector<std::vector<ExampleIndex> > > branch_examples (num_of_features_); // [feature_index][branch_index][result_index]
99  std::vector<std::vector<std::vector<LabelType> > > branch_label_data (num_of_features_); // [feature_index][branch_index][flag_index]
100 
101  // - initialize branch feature results and flags
102  for (size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
103  {
104  branch_feature_results[feature_index].resize (1);
105  branch_flags[feature_index].resize (1);
106  branch_examples[feature_index].resize (1);
107  branch_label_data[feature_index].resize (1);
108 
109  branch_feature_results[feature_index][0] = feature_results[feature_index];
110  branch_flags[feature_index][0] = flags[feature_index];
111  branch_examples[feature_index][0] = examples_;
112  branch_label_data[feature_index][0] = label_data_;
113  }
114 
115  for (size_t depth_index = 0; depth_index < fern_depth_; ++depth_index)
116  {
117  // get thresholds
118  std::vector<std::vector<float> > thresholds (num_of_features_);
119 
120  for (size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
121  {
122  thresholds.reserve (num_of_thresholds_);
123  createThresholdsUniform (num_of_thresholds_, feature_results[feature_index], thresholds[feature_index]);
124  }
125 
126  // compute information gain
127  int best_feature_index = -1;
128  float best_feature_threshold = 0.0f;
129  float best_feature_information_gain = 0.0f;
130 
131  for (size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
132  {
133  for (size_t threshold_index = 0; threshold_index < num_of_thresholds_; ++threshold_index)
134  {
135  float information_gain = 0.0f;
136  for (size_t branch_index = 0; branch_index < branch_feature_results[feature_index].size (); ++branch_index)
137  {
138  const float branch_information_gain = stats_estimator_->computeInformationGain (data_set_,
139  branch_examples[feature_index][branch_index],
140  branch_label_data[feature_index][branch_index],
141  branch_feature_results[feature_index][branch_index],
142  branch_flags[feature_index][branch_index],
143  thresholds[feature_index][threshold_index]);
144 
145  information_gain += branch_information_gain * branch_feature_results[feature_index][branch_index].size ();
146  }
147 
148  if (information_gain > best_feature_information_gain)
149  {
150  best_feature_information_gain = information_gain;
151  best_feature_index = static_cast<int> (feature_index);
152  best_feature_threshold = thresholds[feature_index][threshold_index];
153  }
154  }
155  }
156 
157  // add feature to the feature list of the fern
158  fern.accessFeature (depth_index) = features[best_feature_index];
159  fern.accessThreshold (depth_index) = best_feature_threshold;
160 
161  // update branch feature results and flags
162  for (size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
163  {
164  std::vector<std::vector<float> > & cur_branch_feature_results = branch_feature_results[feature_index];
165  std::vector<std::vector<unsigned char> > & cur_branch_flags = branch_flags[feature_index];
166  std::vector<std::vector<ExampleIndex> > & cur_branch_examples = branch_examples[feature_index];
167  std::vector<std::vector<LabelType> > & cur_branch_label_data = branch_label_data[feature_index];
168 
169  const size_t total_num_of_new_branches = num_of_branches * cur_branch_feature_results.size ();
170 
171  std::vector<std::vector<float> > new_branch_feature_results (total_num_of_new_branches); // [branch_index][example_index]
172  std::vector<std::vector<unsigned char> > new_branch_flags (total_num_of_new_branches); // [branch_index][example_index]
173  std::vector<std::vector<ExampleIndex> > new_branch_examples (total_num_of_new_branches); // [branch_index][example_index]
174  std::vector<std::vector<LabelType> > new_branch_label_data (total_num_of_new_branches); // [branch_index][example_index]
175 
176  for (size_t branch_index = 0; branch_index < cur_branch_feature_results.size (); ++branch_index)
177  {
178  const size_t num_of_examples_in_this_branch = cur_branch_feature_results[branch_index].size ();
179 
180  std::vector<unsigned char> branch_indices;
181  branch_indices.reserve (num_of_examples_in_this_branch);
182 
183  stats_estimator_->computeBranchIndices (cur_branch_feature_results[branch_index],
184  cur_branch_flags[branch_index],
185  best_feature_threshold,
186  branch_indices);
187 
188  // split results into different branches
189  const size_t base_branch_index = branch_index * num_of_branches;
190  for (size_t example_index = 0; example_index < num_of_examples_in_this_branch; ++example_index)
191  {
192  const size_t combined_branch_index = base_branch_index + branch_indices[example_index];
193 
194  new_branch_feature_results[combined_branch_index].push_back (cur_branch_feature_results[branch_index][example_index]);
195  new_branch_flags[combined_branch_index].push_back (cur_branch_flags[branch_index][example_index]);
196  new_branch_examples[combined_branch_index].push_back (cur_branch_examples[branch_index][example_index]);
197  new_branch_label_data[combined_branch_index].push_back (cur_branch_label_data[branch_index][example_index]);
198  }
199  }
200 
201  branch_feature_results[feature_index] = new_branch_feature_results;
202  branch_flags[feature_index] = new_branch_flags;
203  branch_examples[feature_index] = new_branch_examples;
204  branch_label_data[feature_index] = new_branch_label_data;
205  }
206  }
207 
208  // set node statistics
209  // - re-evaluate selected features
210  std::vector<std::vector<float> > final_feature_results (fern_depth_); // [feature_index][example_index]
211  std::vector<std::vector<unsigned char> > final_flags (fern_depth_); // [feature_index][example_index]
212  std::vector<std::vector<unsigned char> > final_branch_indices (fern_depth_); // [feature_index][example_index]
213  for (size_t depth_index = 0; depth_index < fern_depth_; ++depth_index)
214  {
215  final_feature_results[depth_index].reserve (num_of_examples);
216  final_flags[depth_index].reserve (num_of_examples);
217  final_branch_indices[depth_index].reserve (num_of_examples);
218 
219  feature_handler_->evaluateFeature (fern.accessFeature (depth_index),
220  data_set_,
221  examples_,
222  final_feature_results[depth_index],
223  final_flags[depth_index] );
224 
225  stats_estimator_->computeBranchIndices (final_feature_results[depth_index],
226  final_flags[depth_index],
227  fern.accessThreshold (depth_index),
228  final_branch_indices[depth_index]);
229  }
230 
231  // - distribute examples to nodes
232  std::vector<std::vector<LabelType> > node_labels (0x1 << fern_depth_); // [node_index][example_index]
233  std::vector<std::vector<ExampleIndex> > node_examples (0x1 << fern_depth_); // [node_index][example_index]
234 
235  for (size_t example_index = 0; example_index < num_of_examples; ++example_index)
236  {
237  size_t node_index = 0;
238  for (size_t depth_index = 0; depth_index < fern_depth_; ++depth_index)
239  {
240  node_index *= num_of_branches;
241  node_index += final_branch_indices[depth_index][example_index];
242  }
243 
244  node_labels[node_index].push_back (label_data_[example_index]);
245  node_examples[node_index].push_back (examples_[example_index]);
246  }
247 
248  // - compute and set statistics for every node
249  const size_t num_of_nodes = 0x1 << fern_depth_;
250  for (size_t node_index = 0; node_index < num_of_nodes; ++node_index)
251  {
252  stats_estimator_->computeAndSetNodeStats (data_set_, node_examples[node_index], node_labels[node_index], fern[node_index]);
253  }
254 }
255 
256 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////
257 template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
258 void
260  const size_t num_of_thresholds,
261  std::vector<float> & values,
262  std::vector<float> & thresholds)
263 {
264  // estimate range of values
265  float min_value = ::std::numeric_limits<float>::max();
266  float max_value = -::std::numeric_limits<float>::max();
267 
268  const size_t num_of_values = values.size ();
269  for (int value_index = 0; value_index < num_of_values; ++value_index)
270  {
271  const float value = values[value_index];
272 
273  if (value < min_value) min_value = value;
274  if (value > max_value) max_value = value;
275  }
276 
277  const float range = max_value - min_value;
278  const float step = range / (num_of_thresholds+2);
279 
280  // compute thresholds
281  thresholds.resize (num_of_thresholds);
282 
283  for (int threshold_index = 0; threshold_index < num_of_thresholds; ++threshold_index)
284  {
285  thresholds[threshold_index] = min_value + step*(threshold_index+1);
286  }
287 }
288 
289 #endif
static void createThresholdsUniform(const size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
float & accessThreshold(const size_t threshold_index)
Access operator for thresholds.
Definition: fern.h:186
virtual void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const =0
Computes the branch indices obtained by the specified threshold on the supplied feature evaluation re...
FeatureType & accessFeature(const size_t feature_index)
Access operator for features.
Definition: fern.h:168
virtual float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const =0
Computes the information gain obtained by the specified threshold on the supplied feature evaluation ...
virtual size_t getNumOfBranches() const =0
Returns the number of branches a node can have (e.g.
Class representing a Fern.
Definition: fern.h:51
virtual ~FernTrainer()
Destructor.
void train(Fern< FeatureType, NodeType > &fern)
Trains a fern using the set training data and settings.
FernTrainer()
Constructor.
void initialize(const size_t num_of_decisions)
Initializes the fern.
Definition: fern.h:71
virtual void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const =0
Computes and sets the statistics for a node.