1
    2
    3
    4
    5
    6
    7
    8
    9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   21
   22
   23
   24
   25
   26
   27
   28
   29
   30
   31
   32
   33
   34
   35
   36
   37
   38
   39
   40
   41
   42
   43
   44
   45
   46
   47
   48
   49
   50
   51
   52
   53
   54
   55
   56
   57
   58
   59
   60
   61
   62
   63
   64
   65
   66
   67
   68
   69
   70
   71
   72
   73
   74
   75
   76
   77
   78
   79
   80
   81
   82
   83
   84
   85
   86
   87
   88
   89
   90
   91
   92
   93
   94
   95
   96
   97
   98
   99
  100
  101
  102
  103
  104
  105
  106
  107
  108
  109
  110
  111
  112
  113
  114
  115
  116
  117
  118
  119
  120
  121

media / learning / common / target_histogram.cc [blame]

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/learning/common/target_histogram.h"

#include <sstream>

namespace media {
namespace learning {

TargetHistogram::TargetHistogram() = default;

TargetHistogram::TargetHistogram(const TargetHistogram& rhs) = default;

TargetHistogram::TargetHistogram(TargetHistogram&& rhs) = default;

TargetHistogram::~TargetHistogram() = default;

TargetHistogram& TargetHistogram::operator=(const TargetHistogram& rhs) =
    default;

TargetHistogram& TargetHistogram::operator=(TargetHistogram&& rhs) = default;

bool TargetHistogram::operator==(const TargetHistogram& rhs) const {
  return rhs.total_counts() == total_counts() && rhs.counts_ == counts_;
}

TargetHistogram& TargetHistogram::operator+=(const TargetHistogram& rhs) {
  for (auto& rhs_pair : rhs.counts())
    counts_[rhs_pair.first] += rhs_pair.second;

  return *this;
}

TargetHistogram& TargetHistogram::operator+=(const TargetValue& rhs) {
  counts_[rhs]++;
  return *this;
}

TargetHistogram& TargetHistogram::operator+=(const LabelledExample& example) {
  counts_[example.target_value] += example.weight;
  return *this;
}

double TargetHistogram::operator[](const TargetValue& value) const {
  auto iter = counts_.find(value);
  if (iter == counts_.end())
    return 0;

  return iter->second;
}

double& TargetHistogram::operator[](const TargetValue& value) {
  return counts_[value];
}

bool TargetHistogram::FindSingularMax(TargetValue* value_out,
                                      double* counts_out) const {
  if (!counts_.size())
    return false;

  double unused_counts;
  if (!counts_out)
    counts_out = &unused_counts;

  auto iter = counts_.begin();
  *value_out = iter->first;
  *counts_out = iter->second;
  bool singular_max = true;
  for (iter++; iter != counts_.end(); iter++) {
    if (iter->second > *counts_out) {
      *value_out = iter->first;
      *counts_out = iter->second;
      singular_max = true;
    } else if (iter->second == *counts_out) {
      // If this turns out to be the max, then it's not singular.
      singular_max = false;
    }
  }

  return singular_max;
}

double TargetHistogram::Average() const {
  double total_value = 0.;
  double total_counts = 0;
  for (auto& iter : counts_) {
    total_value += iter.first.value() * iter.second;
    total_counts += iter.second;
  }

  if (!total_counts)
    return 0.;

  return total_value / total_counts;
}

void TargetHistogram::Normalize() {
  double total = total_counts();
  for (auto& iter : counts_)
    iter.second /= total;
}

std::string TargetHistogram::ToString() const {
  std::ostringstream ss;
  ss << "[";
  for (auto& entry : counts_)
    ss << " " << entry.first << ":" << entry.second;
  ss << " ]";

  return ss.str();
}

std::ostream& operator<<(std::ostream& out,
                         const media::learning::TargetHistogram& dist) {
  return out << dist.ToString();
}

}  // namespace learning
}  // namespace media