40 template <
class FeatureType,
47 , num_of_features_(1000)
48 , num_of_thresholds_(10)
49 , feature_handler_(nullptr)
50 , stats_estimator_(nullptr)
56 template <
class FeatureType,
65 template <
class FeatureType,
74 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
75 const std::size_t num_of_examples = examples_.size();
78 std::vector<FeatureType> features;
79 feature_handler_->createRandomFeatures(num_of_features_, features);
85 std::vector<std::vector<float>> feature_results(num_of_features_);
86 std::vector<std::vector<unsigned char>> flags(num_of_features_);
88 for (std::size_t feature_index = 0; feature_index < num_of_features_;
90 feature_results[feature_index].reserve(num_of_examples);
91 flags[feature_index].reserve(num_of_examples);
93 feature_handler_->evaluateFeature(features[feature_index],
96 feature_results[feature_index],
97 flags[feature_index]);
101 std::vector<std::vector<std::vector<float>>> branch_feature_results(
103 std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
105 std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
107 std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
111 for (std::size_t feature_index = 0; feature_index < num_of_features_;
113 branch_feature_results[feature_index].resize(1);
114 branch_flags[feature_index].resize(1);
115 branch_examples[feature_index].resize(1);
116 branch_label_data[feature_index].resize(1);
118 branch_feature_results[feature_index][0] = feature_results[feature_index];
119 branch_flags[feature_index][0] = flags[feature_index];
120 branch_examples[feature_index][0] = examples_;
121 branch_label_data[feature_index][0] = label_data_;
124 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
126 std::vector<std::vector<float>> thresholds(num_of_features_);
128 for (std::size_t feature_index = 0; feature_index < num_of_features_;
130 thresholds.reserve(num_of_thresholds_);
131 createThresholdsUniform(num_of_thresholds_,
132 feature_results[feature_index],
133 thresholds[feature_index]);
137 int best_feature_index = -1;
138 float best_feature_threshold = 0.0f;
139 float best_feature_information_gain = 0.0f;
141 for (std::size_t feature_index = 0; feature_index < num_of_features_;
143 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
145 float information_gain = 0.0f;
146 for (std::size_t branch_index = 0;
147 branch_index < branch_feature_results[feature_index].size();
149 const float branch_information_gain =
150 stats_estimator_->computeInformationGain(
152 branch_examples[feature_index][branch_index],
153 branch_label_data[feature_index][branch_index],
154 branch_feature_results[feature_index][branch_index],
155 branch_flags[feature_index][branch_index],
156 thresholds[feature_index][threshold_index]);
159 branch_information_gain *
160 branch_feature_results[feature_index][branch_index].size();
163 if (information_gain > best_feature_information_gain) {
164 best_feature_information_gain = information_gain;
165 best_feature_index =
static_cast<int>(feature_index);
166 best_feature_threshold = thresholds[feature_index][threshold_index];
172 fern.
accessFeature(depth_index) = features[best_feature_index];
176 for (std::size_t feature_index = 0; feature_index < num_of_features_;
178 std::vector<std::vector<float>>& cur_branch_feature_results =
179 branch_feature_results[feature_index];
180 std::vector<std::vector<unsigned char>>& cur_branch_flags =
181 branch_flags[feature_index];
182 std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
183 branch_examples[feature_index];
184 std::vector<std::vector<LabelType>>& cur_branch_label_data =
185 branch_label_data[feature_index];
187 const std::size_t total_num_of_new_branches =
188 num_of_branches * cur_branch_feature_results.size();
190 std::vector<std::vector<float>> new_branch_feature_results(
191 total_num_of_new_branches);
192 std::vector<std::vector<unsigned char>> new_branch_flags(
193 total_num_of_new_branches);
194 std::vector<std::vector<ExampleIndex>> new_branch_examples(
195 total_num_of_new_branches);
196 std::vector<std::vector<LabelType>> new_branch_label_data(
197 total_num_of_new_branches);
199 for (std::size_t branch_index = 0;
200 branch_index < cur_branch_feature_results.size();
202 const std::size_t num_of_examples_in_this_branch =
203 cur_branch_feature_results[branch_index].size();
205 std::vector<unsigned char> branch_indices;
206 branch_indices.reserve(num_of_examples_in_this_branch);
208 stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
209 cur_branch_flags[branch_index],
210 best_feature_threshold,
214 const std::size_t base_branch_index = branch_index * num_of_branches;
215 for (std::size_t example_index = 0;
216 example_index < num_of_examples_in_this_branch;
218 const std::size_t combined_branch_index =
219 base_branch_index + branch_indices[example_index];
221 new_branch_feature_results[combined_branch_index].push_back(
222 cur_branch_feature_results[branch_index][example_index]);
223 new_branch_flags[combined_branch_index].push_back(
224 cur_branch_flags[branch_index][example_index]);
225 new_branch_examples[combined_branch_index].push_back(
226 cur_branch_examples[branch_index][example_index]);
227 new_branch_label_data[combined_branch_index].push_back(
228 cur_branch_label_data[branch_index][example_index]);
232 branch_feature_results[feature_index] = new_branch_feature_results;
233 branch_flags[feature_index] = new_branch_flags;
234 branch_examples[feature_index] = new_branch_examples;
235 branch_label_data[feature_index] = new_branch_label_data;
241 std::vector<std::vector<float>> final_feature_results(
243 std::vector<std::vector<unsigned char>> final_flags(
245 std::vector<std::vector<unsigned char>> final_branch_indices(
247 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
248 final_feature_results[depth_index].reserve(num_of_examples);
249 final_flags[depth_index].reserve(num_of_examples);
250 final_branch_indices[depth_index].reserve(num_of_examples);
252 feature_handler_->evaluateFeature(fern.
accessFeature(depth_index),
255 final_feature_results[depth_index],
256 final_flags[depth_index]);
258 stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
259 final_flags[depth_index],
261 final_branch_indices[depth_index]);
265 std::vector<std::vector<LabelType>> node_labels(
267 std::vector<std::vector<ExampleIndex>> node_examples(
270 for (std::size_t example_index = 0; example_index < num_of_examples;
272 std::size_t node_index = 0;
273 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
274 node_index *= num_of_branches;
275 node_index += final_branch_indices[depth_index][example_index];
278 node_labels[node_index].push_back(label_data_[example_index]);
279 node_examples[node_index].push_back(examples_[example_index]);
283 const std::size_t num_of_nodes = 0x1 << fern_depth_;
284 for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
285 stats_estimator_->computeAndSetNodeStats(data_set_,
286 node_examples[node_index],
287 node_labels[node_index],
292 template <
class FeatureType,
300 std::vector<float>& values,
301 std::vector<float>& thresholds)
304 float min_value = ::std::numeric_limits<float>::max();
305 float max_value = -::std::numeric_limits<float>::max();
307 const std::size_t num_of_values = values.size();
308 for (
int value_index = 0; value_index < num_of_values; ++value_index) {
309 const float value = values[value_index];
311 if (value < min_value)
313 if (value > max_value)
317 const float range = max_value - min_value;
318 const float step = range / (num_of_thresholds + 2);
321 thresholds.resize(num_of_thresholds);
323 for (
int threshold_index = 0; threshold_index < num_of_thresholds;
325 thresholds[threshold_index] = min_value + step * (threshold_index + 1);