// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_

#include <limits>
#include <map>
#include <memory>
#include <set>

#include "base/component_export.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/random_number_generator.h"
#include "media/learning/impl/training_algorithm.h"

namespace media {
namespace learning {

// Trains a RandomTree decision tree classifier / regressor.
//
// Decision trees, including RandomTree, classify instances as follows.  Each
// non-leaf node is marked with a feature number |i|.  The value of the |i|-th
// feature of the instance is then used to select which outgoing edge is
// traversed.  This repeats until arriving at a leaf, which has a distribution
// over target values that is the prediction.  The tree structure, including the
// feature index at each node and distribution at each leaf, is chosen once when
// the tree is trained.
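//
// As a rough sketch (the |Node| type here is hypothetical; the real structure
// lives in the .cc file), prediction looks like:
//
//   const Node* node = root;
//   while (!node->is_leaf())
//     node = node->GetChild(instance.features[node->split_index]);
//   return node->target_distribution;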
//
// Training involves starting with a set of training examples, each of which has
// features and a target value.  The tree is constructed recursively, starting
// with the root.  For the node being constructed, the training algorithm is
// given the portion of the training set that would reach the node, if it were
// sent down the tree in a similar fashion as described above.  It then
// considers assigning each (unused) feature index as the index to split the
// training examples at this node.  For each candidate index |i|, it groups the
// training set into subsets, each of which consists of those examples with the
// same value of the |i|-th feature.  It then computes a score for the split
// using the target values that ended up in each group.  The index with the
// best score is chosen for the split.
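//
// In terms of the private declarations below, the selection loop is roughly
// (|candidate_features| is hypothetical):
//
//   Split best_split;  // defaults to infinite |nats_remaining|
//   for (int i : candidate_features) {
//     Split split = ConstructSplit(task, training_data, training_idx, i);
//     if (split.nats_remaining < best_split.nats_remaining)
//       best_split = std::move(split);
//   }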
//
// For nominal features, we split the feature into all of its nominal values.
// This is somewhat nonstandard; one would normally convert to one-hot numeric
// features first.  See OneHotConverter if you'd like to do this.
//
// For numeric features, we choose a split point uniformly at random between its
// min and max values in the training data.  We do this because it's suitable
// for extra trees.  RandomForest trees instead select the best split point for
// each feature, rather than choosing one uniformly at random.  Either way, we
// choose the best split among the (feature, split point) pairs we consider.
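//
// A sketch, assuming |lo| / |hi| are the observed min / max of the feature in
// the training subset, and an RNG helper that returns a uniform double in
// [0, range) (names hypothetical):
//
//   FeatureValue split_point(lo + rng->GenerateDouble(hi - lo));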
//
// Also note that for one-hot features, the two strategies coincide, since the
// only possible split separates 0 from 1.  So, this implementation is suitable
// for extra trees with numeric (possibly one-hot) features, or RF with one-hot
// nominal features.  Non-one-hot nominal features probably work fine with RF
// too.  Numeric, non-binary features don't work with RF, unless one changes
// the split point selection.
//
// The training algorithm then recurses to build child nodes.  One child node is
// created for each observed value of the |i|-th feature in the training set.
// The child node is trained using the subset of the training set that shares
// that node's value for feature |i|.
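//
// Using the declarations below, the recursion is roughly (|children| is
// hypothetical):
//
//   FeatureSet child_used_set = used_set;
//   child_used_set.insert(best_split.split_index);
//   for (auto& [value, branch] : best_split.branch_infos)
//     children[value] = Build(task, training_data, branch.training_idx,
//                             child_used_set);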
//
// The above is a generic decision tree training algorithm.  A RandomTree
// differs from that mostly in how it selects the feature to split at each node
// during training.  Rather than computing a score for each feature, a
// RandomTree chooses a random subset of the features and only compares those.
//
// See https://en.wikipedia.org/wiki/Random_forest for information.  Note that
// this is just a single tree, not the whole forest.
//
// Note that this variant chooses split points randomly, as described by the
// ExtraTrees algorithm.  This is slightly different than RandomForest, which
// chooses split points to improve the split's score.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
    : public TrainingAlgorithm,
      public HasRandomNumberGenerator {
 public:
  explicit RandomTreeTrainer(RandomNumberGenerator* rng = nullptr);

  RandomTreeTrainer(const RandomTreeTrainer&) = delete;
  RandomTreeTrainer& operator=(const RandomTreeTrainer&) = delete;

  ~RandomTreeTrainer() override;

  // Train on all examples.  Calls |model_cb| with the trained model; the
  // callback will not be invoked before this method returns.
  void Train(const LearningTask& task,
             const TrainingData& examples,
             TrainedModelCB model_cb) override;

 private:
  // Train on the subset |training_idx|.
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& examples,
                               const std::vector<size_t>& training_idx);

  // Set of feature indices.
  using FeatureSet = std::set<int>;

  // Information about a proposed split, and the training sets that would
  // result from that split.
  struct Split {
    Split();
    explicit Split(int index);

    Split(const Split&) = delete;
    Split& operator=(const Split&) = delete;

    Split(Split&& rhs);

    ~Split();

    Split& operator=(Split&& rhs);

    // Feature index to split on.
    size_t split_index = 0;

    // For numeric splits, branch 0 is taken when the feature value is
    // <= |split_point|, and branch 1 when it is greater.
    FeatureValue split_point;

    // Expected nats needed to compute the class, given that we're at this
    // node in the tree.
    // "nat" == entropy measured with natural log rather than base-2.
    double nats_remaining = std::numeric_limits<double>::infinity();

    // Per-branch (i.e. per-child node) information about this split.
    struct BranchInfo {
      explicit BranchInfo();
      BranchInfo(const BranchInfo& rhs) = delete;
      BranchInfo(BranchInfo&& rhs);
      ~BranchInfo();

      BranchInfo& operator=(const BranchInfo& rhs) = delete;
      BranchInfo& operator=(BranchInfo&& rhs) = delete;

      // Training set for this branch of the split.  |training_idx| holds the
      // indices that we're using out of our training data.
      std::vector<size_t> training_idx;

      // Number of occurrences of each target value in |training_data| along
      // this branch of the split.
      // This is a flat_map since we're likely to have a very small number of
      // targets (e.g., "true" / "false").
      TargetHistogram target_histogram;
    };

    // [feature value at this split] = info about which examples take this
    // branch of the split.
    std::map<FeatureValue, BranchInfo> branch_infos;
  };

  // Build this node from |training_data|.  |used_set| is the set of features
  // that we already used higher in the tree.
  std::unique_ptr<Model> Build(const LearningTask& task,
                               const TrainingData& training_data,
                               const std::vector<size_t>& training_idx,
                               const FeatureSet& used_set);

  // Compute and return a split of |training_data| on the |index|-th feature.
  Split ConstructSplit(const LearningTask& task,
                       const TrainingData& training_data,
                       const std::vector<size_t>& training_idx,
                       int index);

  // Fill in |nats_remaining| for |split| for a nominal target.
  // |total_incoming_weight| is the total weight of all instances coming into
  // the node that we're splitting.
  void ComputeSplitScore_Nominal(Split* split, double total_incoming_weight);

  // Fill in |nats_remaining| for |split| for a numeric target.
  void ComputeSplitScore_Numeric(Split* split, double total_incoming_weight);

  // Compute the split point for |training_data| for a nominal feature.
  FeatureValue FindSplitPoint_Nominal(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);

  // Compute the split point for |training_data| for a numeric feature.
  FeatureValue FindSplitPoint_Numeric(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);
};

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_