1
    2
    3
    4
    5
    6
    7
    8
    9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   21
   22
   23
   24
   25
   26
   27
   28
   29
   30
   31
   32
   33
   34
   35
   36
   37
   38
   39
   40
   41
   42
   43
   44
   45
   46
   47
   48
   49
   50
   51
   52
   53
   54
   55
   56
   57
   58
   59
   60
   61
   62
   63
   64
   65
   66
   67
   68
   69
   70
   71
   72
   73
   74
   75
   76
   77
   78
   79
   80
   81
   82
   83
   84
   85
   86
   87
   88
   89
   90
   91
   92
   93
   94
   95
   96
   97
   98
   99
  100
  101
  102
  103
  104
  105
  106
  107
  108
  109
  110
  111
  112
  113
  114
  115
  116
  117
  118
  119
  120
  121
  122
  123
  124
  125
  126
  127
  128
  129
  130
  131
  132
  133
  134

content / test / gpu / gpu_tests / util / websocket_server.py [blame]

# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Code to allow tests to communicate via a websocket server."""

import logging
import threading
from typing import Optional

import websockets  # pylint: disable=import-error
import websockets.sync.server as sync_server  # pylint: disable=import-error

# How long StartServer() waits for the server thread to report its port.
WEBSOCKET_PORT_TIMEOUT_SECONDS = 10
# Default time WaitForConnection() waits for a client to connect.
WEBSOCKET_SETUP_TIMEOUT_SECONDS = 5
# How long ClearCurrentConnection() waits for the connection handler to signal
# that it has finished.
WEBSOCKET_CLOSE_TIMEOUT_SECONDS = 2
# How long StopServer() waits for the server thread to terminate.
SERVER_SHUTDOWN_TIMEOUT_SECONDS = 5

# The client (Chrome) should never be closing the connection. If it does, it's
# indicative of something going wrong like a renderer crash.
ClientClosedConnectionError = websockets.exceptions.ConnectionClosedOK

# Alias for readability. WebsocketServer.Receive() raises this when no message
# arrives within the requested timeout.
WebsocketReceiveMessageTimeoutError = TimeoutError


class WebsocketServer():
  """Server that abstracts the websocket library under the hood.

  Only supports one active connection at a time.

  Typical usage is StartServer() -> WaitForConnection() -> Send()/Receive()
  -> ClearCurrentConnection() (once per client) -> StopServer().
  """

  def __init__(self):
    """Creates an idle server; call StartServer() to begin listening."""
    # Port the server is listening on. Written by the server thread and valid
    # once |port_set_event| is set.
    self.server_port = None
    # The currently active connection, or None if no client is connected.
    self.websocket = None
    # Per-connection events created by the connection handler. Setting the
    # stopper event asks the handler to return; the handler sets the closed
    # event right before it does so. Both are None while no client is
    # connected.
    self.connection_stopper_event = None
    self.connection_closed_event = None
    # Set by the server thread once |server_port| has been filled in.
    self.port_set_event = threading.Event()
    # Set by the connection handler once |websocket| has been filled in.
    self.connection_received_event = threading.Event()
    self._server_thread = None

  def StartServer(self) -> None:
    """Starts the websocket server on a separate thread.

    Blocks until the server thread has reported the port it is listening on.

    Raises:
      RuntimeError: The server did not report a port within
          WEBSOCKET_PORT_TIMEOUT_SECONDS.
    """
    assert self._server_thread is None, 'Server already running'
    # Drop any signal left over from a previous run so that the wait below
    # only succeeds for the thread started here.
    self.port_set_event.clear()
    self._server_thread = _ServerThread(self)
    self._server_thread.daemon = True
    self._server_thread.start()
    got_port = self.port_set_event.wait(WEBSOCKET_PORT_TIMEOUT_SECONDS)
    if not got_port:
      raise RuntimeError('Websocket server did not provide a port')
    # Note: We don't need to set up any port forwarding for remote platforms
    # after this point due to Telemetry's use of --proxy-server to send all
    # traffic through the TsProxyServer. This causes network traffic to pop out
    # on the host, which means that using the websocket server's port directly
    # works.

  def ClearCurrentConnection(self) -> None:
    """Releases the active connection, if any, and resets connection state.

    Asks the connection handler to return and waits up to
    WEBSOCKET_CLOSE_TIMEOUT_SECONDS for it to acknowledge. Safe to call when
    no client is connected.

    Raises:
      RuntimeError: The handler did not acknowledge in time. The connection
          state is left untouched in that case.
    """
    if self.connection_stopper_event:
      self.connection_stopper_event.set()
      closed = self.connection_closed_event.wait(
          WEBSOCKET_CLOSE_TIMEOUT_SECONDS)
      if not closed:
        raise RuntimeError('Websocket connection did not close')
    self.connection_stopper_event = None
    self.connection_closed_event = None
    self.websocket = None
    self.connection_received_event.clear()

  def WaitForConnection(self, timeout: Optional[float] = None) -> None:
    """Blocks until a client has connected.

    Returns immediately if a connection is already active.

    Args:
      timeout: Seconds to wait for a client. Any falsy value (None or 0)
          selects WEBSOCKET_SETUP_TIMEOUT_SECONDS.

    Raises:
      RuntimeError: No client connected within the timeout.
    """
    if self.websocket:
      return
    timeout = timeout or WEBSOCKET_SETUP_TIMEOUT_SECONDS
    self.connection_received_event.wait(timeout)
    # Check the connection itself rather than the wait() result; the handler
    # stores the connection before it sets the event.
    if not self.websocket:
      raise RuntimeError('Websocket connection was not established')

  def StopServer(self) -> None:
    """Releases any active connection and shuts the server thread down.

    Does nothing beyond resetting connection state if the server was never
    started or has already been stopped. After a clean stop the server can be
    started again with StartServer().

    Raises:
      RuntimeError: The active connection did not close in time. The server
          thread is still shut down before the error propagates.
    """
    try:
      self.ClearCurrentConnection()
    finally:
      # Shut the server down even if the connection failed to close so that
      # the listening thread is not leaked.
      server_thread = self._server_thread
      if server_thread is not None:
        server_thread.shutdown()
        server_thread.join(SERVER_SHUTDOWN_TIMEOUT_SECONDS)
        if server_thread.is_alive():
          logging.error(
              'Websocket server did not shut down properly - this might be '
              'indicative of an issue in the test harness')
        else:
          # Only forget the thread once it has really exited, so a lingering
          # server still trips the "already running" check in StartServer().
          self._server_thread = None

  def Send(self, message: str) -> None:
    """Sends |message| to the connected client.

    Args:
      message: The text to send over the active connection.
    """
    self.websocket.send(message)

  def Receive(self, timeout: int) -> str:
    """Waits for and returns the next message from the connected client.

    Args:
      timeout: Seconds to wait for a message.

    Returns:
      The received message.

    Raises:
      WebsocketReceiveMessageTimeoutError: No message arrived within
          |timeout| seconds.
    """
    try:
      return self.websocket.recv(timeout)
    except TimeoutError as e:
      # Re-raise with a message that states how long we waited.
      raise WebsocketReceiveMessageTimeoutError(
          'Timed out after %d seconds waiting for websocket message' %
          timeout) from e


class _ServerThread(threading.Thread):
  def __init__(self, server_instance: WebsocketServer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._server_instance = server_instance
    self.websocket_server = None

  def run(self) -> None:
    StartWebsocketServer(self, self._server_instance)

  def shutdown(self) -> None:
    self.websocket_server.shutdown()


def StartWebsocketServer(server_thread: _ServerThread,
                         server_instance: WebsocketServer) -> None:
  """Runs the websocket server on the calling thread until it is shut down.

  Binds to 127.0.0.1 with port 0, publishes the server object on
  |server_thread| and the port actually assigned on |server_instance|, then
  blocks serving connections.

  Args:
    server_thread: The thread running this function; receives the server
        object so that it can be shut down later.
    server_instance: The WebsocketServer whose port, connection and event
        attributes are updated as the server runs.
  """

  def _ServeSingleConnection(
      connection: sync_server.ServerConnection) -> None:
    """Registers |connection| as the active one and parks until released."""
    # Exactly one connection may be active at a time, so all per-connection
    # state must have been cleared before a new client shows up.
    assert server_instance.connection_stopper_event is None
    assert server_instance.connection_closed_event is None
    assert server_instance.websocket is None
    # Hold the events in locals as well: the server may drop its references
    # before this handler has finished waiting on them.
    stop_requested = threading.Event()
    handler_finished = threading.Event()
    server_instance.connection_stopper_event = stop_requested
    server_instance.connection_closed_event = handler_finished
    # Store the connection before signalling so that anyone woken by the event
    # is guaranteed to see it.
    server_instance.websocket = connection
    server_instance.connection_received_event.set()
    # Park here until the server asks for the connection to be released, then
    # acknowledge just before returning.
    stop_requested.wait()
    handler_finished.set()

  with sync_server.serve(_ServeSingleConnection, '127.0.0.1',
                         0) as websocket_server:
    server_thread.websocket_server = websocket_server
    bound_address = websocket_server.socket.getsockname()
    server_instance.server_port = bound_address[1]
    # Wake up StartServer(), which is waiting for the port to become known.
    server_instance.port_set_event.set()
    websocket_server.serve_forever()