1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
media / learning / common / learning_task_controller.h [blame]
// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_COMMON_LEARNING_TASK_CONTROLLER_H_
#define MEDIA_LEARNING_COMMON_LEARNING_TASK_CONTROLLER_H_
#include <optional>
#include "base/component_export.h"
#include "base/functional/callback.h"
#include "base/unguessable_token.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/target_histogram.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
namespace media {
namespace learning {
// Wrapper struct for completing an observation via LearningTaskController.
// Most callers only have a TargetValue, so the implicit converting
// constructor lets a bare TargetValue be passed wherever an
// ObservationCompletion is expected, with a default weight of 1.
// Optional per-observation data such as the UKM SourceId is supplied to
// LearningTaskController::BeginObservation rather than carried here.
struct ObservationCompletion {
  // Produces a completion with a default-constructed |target_value| and a
  // weight of 1 (via the in-class initializer below).
  ObservationCompletion() = default;

  // Intentionally implicit so call sites can pass a TargetValue directly.
  /* implicit */ ObservationCompletion(const TargetValue& target,
                                       WeightType w = 1.)
      : target_value(target), weight(w) {}

  // The observed target value for this observation.
  TargetValue target_value;

  // Weight of the resulting example. Initialized in-class so that a
  // default-constructed completion has a well-defined weight that matches the
  // converting constructor's default. Without this, if WeightType is a scalar
  // (presumably double), the defaulted constructor would leave it
  // indeterminate, and operator== below would read it.
  WeightType weight = 1.;

  // Mostly for gmock matchers.
  bool operator==(const ObservationCompletion& rhs) const {
    return target_value == rhs.target_value && weight == rhs.weight;
  }
};
// Per-task client interface: code that produces FeatureVectors and/or wants
// predictions for one LearningTask talks to the framework through this class.
//
// An observation is split into a "begin" step that carries only features and
// a later completion step that supplies the TargetValue. This lets
// framework-provided features (FeatureProvider) be captured at the moment a
// prediction would be made, instead of waiting until the TargetValue is
// eventually observed.
class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
 public:
  // Receives the predicted distribution, or std::nullopt when no model is
  // available (see PredictDistribution).
  using PredictionCB =
      base::OnceCallback<void(const std::optional<TargetHistogram>& prediction)>;

  LearningTaskController() = default;

  // Neither copyable nor movable: deleting the copy operations (together with
  // the user-declared destructor) also suppresses the implicit move
  // operations.
  LearningTaskController(const LearningTaskController&) = delete;
  LearningTaskController& operator=(const LearningTaskController&) = delete;

  virtual ~LearningTaskController() = default;

  // Opens observation |id| with the given |features|. Call this at the point
  // where the TargetValue would be predicted, so the framework can snapshot
  // any framework-provided feature values at that moment.
  //
  // Each begun observation should later be either:
  //   - completed with CompleteObservation(|id|, ...) to turn the features
  //     into a training example, or
  //   - dropped with CancelObservation(|id|).
  // Destroying the controller while observations are still open is allowed.
  // Each one is completed with its |default_target| if it has one, and
  // cancelled otherwise. |source_id| is an optional UKM SourceId for the
  // observation.
  //
  // Note: default arguments on a virtual function are taken from the static
  // type at the call site, not from the override that actually runs.
  //
  // TODO(liberato): This should optionally take a callback to receive a
  // prediction for the FeatureVector.
  // TODO(liberato): See if this ends up generating smaller code with pass-by-
  // value or with |FeatureVector&&|, once we have callers that can actually
  // benefit from it.
  virtual void BeginObservation(
      base::UnguessableToken id,
      const FeatureVector& features,
      const std::optional<TargetValue>& default_target = std::nullopt,
      const std::optional<ukm::SourceId>& source_id = std::nullopt) = 0;

  // Finishes observation |id| using |completion| (target value and weight).
  virtual void CompleteObservation(
      base::UnguessableToken id,
      const ObservationCompletion& completion) = 0;

  // Tells the controller that observation |id| will never be completed.
  virtual void CancelObservation(base::UnguessableToken id) = 0;

  // Replaces the default target of observation |id| with |default_target|.
  // This can add a default where there was none, change an existing one, or
  // clear it by passing std::nullopt. Once cleared, destroying the controller
  // cancels the observation instead of completing it, exactly as if no
  // default had ever been given.
  virtual void UpdateDefaultTarget(
      base::UnguessableToken id,
      const std::optional<TargetValue>& default_target) = 0;

  // The LearningTask this controller serves.
  virtual const LearningTask& GetLearningTask() = 0;

  // Requests a predicted distribution for |features|, delivered through
  // |callback|. The callback receives std::nullopt when no model is
  // available. It may be run synchronously, before this call returns, rather
  // than being posted.
  virtual void PredictDistribution(const FeatureVector& features,
                                   PredictionCB callback) = 0;
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_COMMON_LEARNING_TASK_CONTROLLER_H_