1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
ash / components / kcer / token_results_merger.h [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef ASH_COMPONENTS_KCER_TOKEN_RESULTS_MERGER_H_
#define ASH_COMPONENTS_KCER_TOKEN_RESULTS_MERGER_H_
#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

#include "ash/components/kcer/kcer.h"
#include "base/check.h"
#include "base/check_op.h"
#include "base/containers/contains.h"
#include "base/containers/flat_map.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_refptr.h"
#include "base/types/expected.h"
#include "base/types/pass_key.h"
namespace kcer::internal {
// This type is used for unit tests.
struct MoveOnlyType;

// A helper class for Kcer methods that work with several tokens in parallel.
// Collects and aggregates results from each token before returning the end
// result.
//
// Usage: create an instance with Create(), then call GetCallback() exactly
// `results_to_receive` times (once per token) and hand each returned callback
// to the corresponding token. When the last of them has run, the end callback
// receives the concatenation of all successful results together with a map of
// the errors reported by the failed tokens.
template <typename T>
class TokenResultsMerger : public base::RefCounted<TokenResultsMerger<T>> {
 public:
  static_assert(std::is_same_v<T, scoped_refptr<const Cert>> ||
                    std::is_same_v<T, PublicKey> ||
                    std::is_same_v<T, MoveOnlyType>,
                "TokenResultsMerger only supports scoped_refptr<const Cert>, "
                "PublicKey and MoveOnlyType");

  // `results_to_receive` is the amount of results to collect before
  // aggregating and returning them. Must be positive.
  static scoped_refptr<TokenResultsMerger> Create(
      int results_to_receive,
      base::OnceCallback<void(std::vector<T>, base::flat_map<Token, Error>)>
          callback);

  // Use Create() instead. The PassKey can only be created by this class,
  // which keeps outside code from constructing instances directly.
  TokenResultsMerger(
      base::PassKey<TokenResultsMerger<T>>,
      int results_to_receive,
      base::OnceCallback<void(std::vector<T>, base::flat_map<Token, Error>)>
          callback)
      : callbacks_to_create_(results_to_receive),
        results_to_receive_(results_to_receive),
        callback_(std::move(callback)) {
    CHECK_GT(results_to_receive_, 0);
  }

  // Returns a callback to collect one result from `token`, the callback
  // will co-own `this` instance. Must be called exactly `results_to_receive`
  // times in total. The returned callback must be run: dropping it means the
  // end result is never delivered.
  [[nodiscard]] base::OnceCallback<void(base::expected<std::vector<T>, Error>)>
  GetCallback(Token token);

 private:
  friend class base::RefCounted<TokenResultsMerger<T>>;

  ~TokenResultsMerger() {
    // A positive value means that fewer than `results_to_receive` callbacks
    // were created, so the end result could never have been returned.
    // Creating too many callbacks is already caught by the CHECK in
    // GetCallback(), so the value cannot be negative here.
    CHECK_EQ(callbacks_to_create_, 0);
  }

  // Records the outcome reported for `token` and, once all expected results
  // are in, runs `callback_` with the aggregated data.
  void HandleOneResult(Token token,
                       base::expected<std::vector<T>, Error> result);

  // Guardrail variable to ensure that the merger is used correctly: counts
  // how many more callbacks GetCallback() is allowed to hand out.
  int callbacks_to_create_ = 0;
  // Counter for how many results should still be received.
  int results_to_receive_ = 0;
  // Callback for the end result.
  base::OnceCallback<void(std::vector<T>, base::flat_map<Token, Error>)>
      callback_;
  // Objects from succeeded tokens.
  std::vector<T> good_results_;
  // Errors from failed tokens.
  base::flat_map<Token, Error> errors_;
};
// static
template <typename T>
scoped_refptr<TokenResultsMerger<T>> TokenResultsMerger<T>::Create(
    int results_to_receive,
    base::OnceCallback<void(std::vector<T>, base::flat_map<Token, Error>)>
        callback) {
  // Only this class can mint a PassKey for itself, so routing construction
  // through this factory guarantees that every instance starts out owned by
  // a scoped_refptr.
  scoped_refptr<TokenResultsMerger> merger =
      base::MakeRefCounted<TokenResultsMerger>(
          base::PassKey<TokenResultsMerger>(), results_to_receive,
          std::move(callback));
  return merger;
}
template <typename T>
base::OnceCallback<void(base::expected<std::vector<T>, Error>)>
TokenResultsMerger<T>::GetCallback(Token token) {
  // Every handed-out callback accounts for one expected sub-result; asking
  // for more callbacks than results is a programming error.
  --callbacks_to_create_;
  CHECK_GE(callbacks_to_create_, 0);

  // The bound callback co-owns the merger through RetainedRef, so the merger
  // stays alive while sub-results are still pending.
  return base::BindOnce(&TokenResultsMerger::HandleOneResult,
                        base::RetainedRef(this), token);
}
// Consumes the sub-result reported for `token`. Successful results are
// appended to `good_results_`; a failure is recorded in `errors_` (each token
// may report at most one error). When the last expected sub-result arrives,
// the aggregated data is moved into `callback_`.
template <typename T>
void TokenResultsMerger<T>::HandleOneResult(
    Token token,
    base::expected<std::vector<T>, Error> result) {
  if (result.has_value()) {
    std::vector<T>& token_results = result.value();
    if (good_results_.empty()) {
      // Nothing collected yet: adopt the whole vector (and its buffer)
      // instead of moving every element into a freshly allocated one.
      good_results_ = std::move(token_results);
    } else {
      good_results_.reserve(good_results_.size() + token_results.size());
      std::move(token_results.begin(), token_results.end(),
                std::back_inserter(good_results_));
    }
  } else {
    // Single lookup: emplace() does not overwrite an existing entry, so a
    // failed insertion means `token` already reported an error.
    const bool inserted =
        errors_.emplace(token, std::move(result).error()).second;
    CHECK(inserted);
  }

  --results_to_receive_;
  if (results_to_receive_ == 0) {
    // All expected sub-results are in; hand over the aggregate.
    std::move(callback_).Run(std::move(good_results_), std::move(errors_));
  }
}
} // namespace kcer::internal
#endif // ASH_COMPONENTS_KCER_TOKEN_RESULTS_MERGER_H_