1
    2
    3
    4
    5
    6
    7
    8
    9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   21
   22
   23
   24
   25
   26
   27
   28
   29
   30
   31
   32
   33
   34
   35
   36
   37
   38
   39
   40
   41
   42
   43
   44
   45
   46
   47
   48
   49
   50
   51
   52
   53
   54
   55
   56
   57
   58
   59
   60
   61
   62
   63
   64
   65
   66
   67
   68
   69
   70
   71
   72
   73
   74
   75
   76
   77
   78
   79
   80
   81
   82
   83
   84
   85
   86
   87
   88
   89
   90
   91
   92
   93
   94
   95
   96
   97
   98
   99
  100
  101
  102
  103
  104
  105
  106
  107
  108
  109
  110
  111
  112
  113
  114
  115
  116
  117
  118
  119

media / learning / common / labelled_example.h [blame]

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_

#include <initializer_list>
#include <ostream>
#include <vector>

#include "base/check_op.h"
#include "base/component_export.h"
#include "media/learning/common/value.h"

namespace media {
namespace learning {

// Vector of features, for training or prediction.
// To interpret the features, one probably needs to check a LearningTask.  It
// provides a description for each index.  For example, [0]=="height",
// [1]=="url", etc.
using FeatureVector = std::vector<FeatureValue>;

// Unsigned integer type used for the weight of a single LabelledExample
// (LabelledExample::weight) and for the summed weight of a TrainingData
// (TrainingData::total_weight()).
using WeightType = size_t;

// A single training example: the observed feature values, the target value
// that was observed for them, and an integer weight.
struct COMPONENT_EXPORT(LEARNING_COMMON) LabelledExample {
  // Construction, copy / move, and destruction.  Both copying and moving are
  // supported; the move operations are declared noexcept.
  LabelledExample();
  LabelledExample(FeatureVector feature_vector, TargetValue target);
  LabelledExample(std::initializer_list<FeatureValue> init_list,
                  TargetValue target);
  LabelledExample(const LabelledExample& rhs);
  LabelledExample(LabelledExample&& rhs) noexcept;
  ~LabelledExample();

  LabelledExample& operator=(const LabelledExample& rhs);
  LabelledExample& operator=(LabelledExample&& rhs) noexcept;

  // Equality and ordering.  As a convenience, |weight| is not taken into
  // account by any of these.
  bool operator==(const LabelledExample& rhs) const;
  bool operator!=(const LabelledExample& rhs) const;
  bool operator<(const LabelledExample& rhs) const;

  // Feature values that were observed.  Making sense of them generally
  // requires the LearningTask they are meant to be used with.
  FeatureVector features;

  // Output value that was observed for |features|.
  TargetValue target_value;

  // Weight of this example; 1 unless set otherwise.
  WeightType weight{1u};
};

// Ordered collection of LabelledExamples, together with the running sum of
// the weights of the examples added via push_back().
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
 public:
  using ExampleVector = std::vector<LabelledExample>;
  using const_iterator = ExampleVector::const_iterator;

  TrainingData();
  TrainingData(const TrainingData& rhs);
  TrainingData(TrainingData&& rhs);

  TrainingData& operator=(const TrainingData& rhs);
  TrainingData& operator=(TrainingData&& rhs);

  ~TrainingData();

  // Appends a copy of |example| and adds |example.weight| to total_weight().
  // |example.weight| must be nonzero; note that this is only enforced by a
  // DCHECK, so it is not checked in release builds.
  void push_back(const LabelledExample& example) {
    DCHECK_GT(example.weight, 0u);
    examples_.push_back(example);
    total_weight_ += example.weight;
  }

  // Returns true if there are no examples.  This is based on the stored
  // examples rather than on total_weight(), so that it always agrees with
  // size() == 0 and begin() == end(), even if a zero-weight example got past
  // the (debug-only) check in push_back().
  bool empty() const { return examples_.empty(); }

  // Returns the number of stored examples, ignoring their weights.
  size_t size() const { return examples_.size(); }

  // Returns the number of instances, taking into account their weight.  For
  // example, if one adds an example with weight 2, then this will return two
  // more than it did before.
  WeightType total_weight() const { return total_weight_; }

  const_iterator begin() const { return examples_.begin(); }
  const_iterator end() const { return examples_.end(); }

  // Returns true if the total weight equals the number of examples.  Assuming
  // every weight is at least 1 (as push_back() DCHECKs), this means every
  // example has weight exactly 1.
  bool is_unweighted() const { return examples_.size() == total_weight_; }

  // Provide the |i|-th example, over [0, size()).
  const LabelledExample& operator[](size_t i) const { return examples_[i]; }
  // NOTE: changing an example's |weight| through this reference does not
  // update total_weight(); only push_back() maintains that sum.
  LabelledExample& operator[](size_t i) { return examples_[i]; }

  // Return a copy of this data with duplicate entries merged.  Example weights
  // will be summed.
  TrainingData DeDuplicate() const;

 private:
  ExampleVector examples_;

  // Running sum of the weights of the examples added via push_back().
  WeightType total_weight_ = 0u;

  // Copy / assignment is allowed.
};

// Stream-insertion support for a FeatureVector.  The output format is defined
// out of line.
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const FeatureVector& features);

// Stream-insertion support for a LabelledExample.  The output format is
// defined out of line.
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const LabelledExample& example);

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_