1
    2
    3
    4
    5
    6
    7
    8
    9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   21
   22
   23
   24
   25
   26
   27
   28
   29
   30
   31
   32
   33
   34
   35
   36
   37
   38
   39
   40
   41
   42
   43
   44
   45
   46
   47
   48
   49
   50
   51
   52
   53
   54
   55
   56
   57
   58
   59
   60
   61
   62
   63
   64
   65
   66
   67
   68
   69
   70
   71
   72
   73
   74
   75
   76
   77
   78
   79
   80
   81
   82
   83
   84
   85
   86
   87
   88
   89
   90
   91
   92
   93
   94
   95
   96
   97
   98
   99
  100
  101
  102
  103
  104
  105
  106
  107
  108
  109
  110
  111
  112
  113
  114
  115
  116
  117
  118
  119
  120
  121
  122
  123
  124
  125
  126
  127
  128
  129
  130
  131
  132

content / browser / aggregation_service / aggregation_service_key_fetcher.cc [blame]

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "base/containers/circular_deque.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/rand_util.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/browser/aggregation_service/aggregation_service_storage_context.h"
#include "services/network/public/cpp/is_potentially_trustworthy.h"
#include "url/gurl.h"

namespace content {

// Constructs a key fetcher that reads/writes stored keys through
// `storage_context` and fetches missing keys through `network_fetcher`.
//
// Takes ownership of `network_fetcher`. A null `network_fetcher` is tolerated:
// FetchPublicKeysFromNetwork() then reports a failure instead of fetching.
// `storage_context` is stored as given and dereferenced by later calls
// (presumably it must outlive `this` -- confirm against the header).
AggregationServiceKeyFetcher::AggregationServiceKeyFetcher(
    AggregationServiceStorageContext* storage_context,
    std::unique_ptr<NetworkFetcher> network_fetcher)
    : storage_context_(storage_context),
      network_fetcher_(std::move(network_fetcher)) {}

// Defaulted: there is no explicit teardown logic; members are cleaned up by
// their own destructors.
AggregationServiceKeyFetcher::~AggregationServiceKeyFetcher() = default;

// Requests a public key for `url`. `callback` is queued and later run by
// RunCallbacksForUrl() with either a randomly selected key or a failure
// status. `url` must be potentially trustworthy (enforced with a CHECK).
//
// Concurrent requests for the same `url` are coalesced: only the first one
// starts a lookup, the rest simply wait for its result. The lookup consults
// storage first; the network is only tried if storage has no keys.
void AggregationServiceKeyFetcher::GetPublicKey(const GURL& url,
                                                FetchCallback callback) {
  CHECK(network::IsUrlPotentiallyTrustworthy(url));

  // A non-empty queue means a lookup for `url` is already underway.
  auto& waiters = url_callbacks_[url];
  const bool is_first_waiter = waiters.empty();
  waiters.push_back(std::move(callback));

  if (!is_first_waiter) {
    // The in-flight lookup will service this callback when it completes.
    return;
  }

  // Bound to a weak pointer so the reply is dropped if `this` is gone by the
  // time storage answers.
  auto on_stored_keys = base::BindOnce(
      &AggregationServiceKeyFetcher::OnPublicKeysReceivedFromStorage,
      weak_factory_.GetWeakPtr(), url);

  // Start by asking storage whether keys for `url` are already cached.
  // TODO(crbug.com/40187645): Pass url by value and move after C++17.
  storage_context_->GetStorage()
      .AsyncCall(&AggregationServiceStorage::GetPublicKeys)
      .WithArgs(url)
      .Then(std::move(on_stored_keys));
}

// Reply handler for the storage lookup started by GetPublicKey(). Serves the
// waiting callbacks straight from `keys` when storage had any; otherwise
// falls back to fetching the keys for `url` from the network.
void AggregationServiceKeyFetcher::OnPublicKeysReceivedFromStorage(
    const GURL& url,
    std::vector<PublicKey> keys) {
  if (!keys.empty()) {
    // Cache hit: no network round trip needed.
    RunCallbacksForUrl(url, keys);
    return;
  }

  // Nothing stored for `url`; try the network instead.
  FetchPublicKeysFromNetwork(url);
}

// Starts a network fetch of the public keys for `url`, with the result
// delivered to OnPublicKeysReceivedFromNetwork(). If no network fetcher was
// supplied at construction, the waiting callbacks are failed immediately.
void AggregationServiceKeyFetcher::FetchPublicKeysFromNetwork(const GURL& url) {
  if (network_fetcher_) {
    // `base::Unretained` is safe here: `network_fetcher_` is owned by `this`
    // and is therefore destroyed before `this` is.
    auto on_network_keys = base::BindOnce(
        &AggregationServiceKeyFetcher::OnPublicKeysReceivedFromNetwork,
        base::Unretained(this), url);
    network_fetcher_->FetchPublicKeys(url, std::move(on_network_keys));
    return;
  }

  // Network fetching is not enabled: an empty key list signals the error.
  RunCallbacksForUrl(url, /*keys=*/{});
}

// Reply handler for the network fetch started by FetchPublicKeysFromNetwork().
//
// `keyset` is std::nullopt if the fetch failed. Storage is updated first and
// then the callbacks waiting on `url` are run:
//   * no keyset, or a keyset with a null `expiry_time`: the keys stored for
//     `url` are cleared;
//   * otherwise: the fetched keyset replaces whatever was stored for `url`.
// The callbacks receive the fetched keys whenever a keyset was returned (even
// if it was not stored), and an empty list -- i.e. failure -- otherwise.
void AggregationServiceKeyFetcher::OnPublicKeysReceivedFromNetwork(
    const GURL& url,
    std::optional<PublicKeyset> keyset) {
  if (!keyset.has_value() || keyset->expiry_time.is_null()) {
    // `keyset` will be std::nullopt if an error occurred and `expiry_time`
    // will be null if the freshness lifetime was zero. In these cases, we will
    // still update the keys for `url`, i.e. clear them.
    storage_context_->GetStorage()
        .AsyncCall(&AggregationServiceStorage::ClearPublicKeys)
        .WithArgs(url);
  } else {
    // Store public keys fetched from network to storage, the old keys will be
    // deleted from storage.
    storage_context_->GetStorage()
        .AsyncCall(&AggregationServiceStorage::SetPublicKeys)
        .WithArgs(url, keyset.value());
  }

  // Pass the fetched keys by reference. A conditional expression mixing
  // `keyset->keys` with a temporary vector would yield a prvalue and thus
  // copy the whole key list just to bind it to a const reference.
  if (keyset.has_value()) {
    RunCallbacksForUrl(url, keyset->keys);
  } else {
    RunCallbacksForUrl(url, /*keys=*/{});
  }
}

void AggregationServiceKeyFetcher::RunCallbacksForUrl(
    const GURL& url,
    const std::vector<PublicKey>& keys) {
  auto iter = url_callbacks_.find(url);
  CHECK(iter != url_callbacks_.end());

  base::circular_deque<FetchCallback> pending_callbacks =
      std::move(iter->second);
  CHECK(!pending_callbacks.empty());

  url_callbacks_.erase(iter);

  if (keys.empty()) {
    // Return error, don't refetch to avoid infinite loop.
    for (auto& callback : pending_callbacks) {
      std::move(callback).Run(std::nullopt,
                              PublicKeyFetchStatus::kPublicKeyFetchFailed);
    }
  } else {
    for (auto& callback : pending_callbacks) {
      // Each report should randomly select a key. This ensures that the set of
      // reports a client sends are not a subset of the reports identified by
      // any one key.
      uint64_t key_index = base::RandGenerator(keys.size());
      std::move(callback).Run(keys[key_index], PublicKeyFetchStatus::kOk);
    }
  }
}

}  // namespace content