aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/http2_test/http2_test_server.py
blob: be5f1593ebcf0dca72f72250d9dd1f9c3dec37b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
  HTTP2 Test Server. Highly experimental work in progress.
"""
import struct
import messages_pb2
import argparse
import logging
import time

from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.protocol import Protocol, Factory
from twisted.internet import endpoints, reactor, error, defer
from h2.connection import H2Connection
from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged
from threading import Lock
import http2_base_server

READ_CHUNK_SIZE = 16384
GRPC_HEADER_SIZE = 5

class TestcaseRstStreamAfterHeader(object):
  """Test case: reply to a request with headers, then reset the stream.

  Owns an H2ProtocolBaseServer and replaces its 'RequestReceived' handler
  with on_request_received.
  """

  def __init__(self):
    server = http2_base_server.H2ProtocolBaseServer()
    server._handlers['RequestReceived'] = self.on_request_received
    self._base_server = server

  def get_base_server(self):
    """Return the wrapped base server instance."""
    return self._base_server

  def on_request_received(self, event):
    """Run the default request handling, then send a stream reset.

    Args:
      event: the request event, forwarded unchanged to the base server.
    """
    server = self._base_server
    # The default handler sends the initial headers.
    server.on_request_received_default(event)
    # Reset the stream straight after the headers.
    server.send_reset_stream()

class TestcaseRstStreamAfterData(object):
  """Test case: queue the response data for a request, then reset the stream.

  Owns an H2ProtocolBaseServer and replaces its 'DataReceived' handler with
  on_data_received.
  """

  def __init__(self):
    self._base_server = http2_base_server.H2ProtocolBaseServer()
    self._base_server._handlers['DataReceived'] = self.on_data_received
    # Defined up front so the attribute exists before any data arrives.
    self._ready_to_send = False

  def get_base_server(self):
    """Return the wrapped base server instance."""
    return self._base_server

  def on_data_received(self, event):
    """Buffer incoming data; once a request parses, respond and reset.

    Args:
      event: the data event, forwarded unchanged to the base server.

    Raises:
      AssertionError: if the requested response is larger than 2048 bytes.
    """
    self._base_server.on_data_received_default(event)
    sr = self._base_server.parse_received_data(self._base_server._recv_buffer)
    if sr is None:
      # Nothing parsed yet. Presumably the request is still incomplete
      # (TestcaseGoaway guards the same call) -- confirm against
      # http2_base_server. Wait for the next data event instead of failing.
      return
    # The response has to fit into one flow control window.
    assert sr.response_size <= 2048, (
        'response_size %d exceeds the 2048 byte limit' % sr.response_size)
    response_data = self._base_server.default_response_data(sr.response_size)
    self._ready_to_send = True
    self._base_server.setup_send(response_data)
    # Reset the stream right after the data has been handed over.
    self._base_server.send_reset_stream()

class TestcaseGoaway(object):
  """Test case: serve a request normally, then send GOAWAY.

  After the response on stream 1 has been sent, a GOAWAY with
  last_stream_id=1 is issued. The next request is expected to arrive on a
  different connection; `iteration` is the number the factory assigned to
  this connection.
  """

  def __init__(self, iteration):
    server = http2_base_server.H2ProtocolBaseServer()
    overrides = (
        ('RequestReceived', self.on_request_received),
        ('DataReceived', self.on_data_received),
        ('WindowUpdated', self.on_window_update_default),
        ('SendDone', self.on_send_done),
        ('ConnectionLost', self.on_connection_lost),
    )
    for event_name, handler in overrides:
      server._handlers[event_name] = handler
    self._base_server = server
    self._ready_to_send = False
    self._iteration = iteration

  def get_base_server(self):
    """Return the wrapped base server instance."""
    return self._base_server

  def on_connection_lost(self, reason):
    """Log the disconnect; forward it to the base server on iteration 2."""
    logging.info('Disconnect received. Count %d'%self._iteration)
    if self._iteration != 2:
      return
    # Iteration 2 means two different connections have been used.
    self._base_server.on_connection_lost(reason)

  def on_send_done(self):
    """Finish the send as usual, then send GOAWAY if this was stream 1."""
    server = self._base_server
    server.on_send_done_default()
    if server._stream_id != 1:
      return
    logging.info('Sending GOAWAY for stream 1')
    server._conn.close_connection(
        error_code=0, additional_data=None, last_stream_id=1)

  def on_request_received(self, event):
    """Clear the ready-to-send flag and run the default request handling."""
    self._ready_to_send = False
    self._base_server.on_request_received_default(event)

  def on_data_received(self, event):
    """Buffer incoming data; once a request parses, queue the response."""
    server = self._base_server
    server.on_data_received_default(event)
    request = server.parse_received_data(server._recv_buffer)
    if not request:
      return
    # NOTE(review): this blocks the calling thread for one second;
    # presumably a deliberate delay -- confirm.
    time.sleep(1)
    logging.info('Creating response size = %s'%request.response_size)
    payload = server.default_response_data(request.response_size)
    self._ready_to_send = True
    server.setup_send(payload)

  def on_window_update_default(self, event):
    """Send pending data, but only once a response has been queued."""
    if not self._ready_to_send:
      return
    self._base_server.default_send()

class TestcasePing(object):
  """Test case: send pings around the normal request/response handling.

  A ping is sent before and after the default request handling, and before
  and after the response data is queued. When the connection is lost, the
  base server's outstanding ping count is asserted to be zero.
  """

  def __init__(self, iteration):
    """Create the base server and install this test case's handlers.

    Args:
      iteration: number assigned by the factory; not used by this test case.
    """
    self._base_server = http2_base_server.H2ProtocolBaseServer()
    self._base_server._handlers['RequestReceived'] = self.on_request_received
    self._base_server._handlers['DataReceived'] = self.on_data_received
    self._base_server._handlers['ConnectionLost'] = self.on_connection_lost

  def get_base_server(self):
    """Return the wrapped base server instance."""
    return self._base_server

  def on_request_received(self, event):
    """Run the default request handling with a ping before and after it."""
    self._base_server.default_ping()
    self._base_server.on_request_received_default(event)
    self._base_server.default_ping()

  def on_data_received(self, event):
    """Buffer incoming data; once a request parses, queue the response.

    A ping is sent immediately before and after the response is queued.
    """
    self._base_server.on_data_received_default(event)
    sr = self._base_server.parse_received_data(self._base_server._recv_buffer)
    if not sr:
      # Nothing parsed yet. Same guard as TestcaseGoaway; without it the
      # attribute access below fails when the parse result is None.
      return
    logging.info('Creating response size = %s'%sr.response_size)
    response_data = self._base_server.default_response_data(sr.response_size)
    self._base_server.default_ping()
    self._base_server.setup_send(response_data)
    self._base_server.default_ping()

  def on_connection_lost(self, reason):
    """Check that no pings are outstanding, then forward the disconnect.

    Raises:
      AssertionError: if the base server still reports outstanding pings.
    """
    logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings)
    assert(self._base_server._outstanding_pings == 0)
    self._base_server.on_connection_lost(reason)

class H2Factory(Factory):
  """Twisted factory that builds a fresh test-case server per protocol.

  Args:
    testcase: name of the test case to run; one of
      'rst_stream_after_header', 'rst_stream_after_data', 'goaway', 'ping'.
  """

  def __init__(self, testcase):
    logging.info('In H2Factory')
    # Number of protocols built so far; handed to the test cases that take
    # an iteration number.
    self._num_streams = 0
    self._testcase = testcase

  def buildProtocol(self, addr):
    """Create the configured test case and return its base server.

    Args:
      addr: peer address supplied by Twisted; unused.

    Returns:
      The base server object of the newly created test case.

    Raises:
      ValueError: if the configured test case name is not recognised.
    """
    self._num_streams += 1
    if self._testcase == 'rst_stream_after_header':
      # The rst_stream test cases have argument-less constructors; passing
      # the counter to them raised TypeError.
      t = TestcaseRstStreamAfterHeader()
    elif self._testcase == 'rst_stream_after_data':
      t = TestcaseRstStreamAfterData()
    elif self._testcase == 'goaway':
      t = TestcaseGoaway(self._num_streams)
    elif self._testcase == 'ping':
      t = TestcasePing(self._num_streams)
    else:
      # Explicit error rather than assert(0), which is stripped under -O and
      # would leave `t` unbound.
      raise ValueError('unknown test case: %s' % self._testcase)
    return t.get_base_server()

if __name__ == "__main__":
  logging.basicConfig(format = "%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s", level=logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument("test")
  parser.add_argument("port")
  args = parser.parse_args()
  if args.test not in ['rst_stream_after_header', 'rst_stream_after_data', 'goaway', 'ping']:
    print 'unknown test: ', args.test
  endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128)
  endpoint.listen(H2Factory(args.test))
  reactor.run()